more imagenet training features
david8862 committed Jan 5, 2022
1 parent 525a379 commit e9f0f94
Showing 3 changed files with 119 additions and 49 deletions.
22 changes: 15 additions & 7 deletions common/backbones/imagenet_training/README.md
@@ -34,11 +34,14 @@ The validation preprocess script and synset labels file are from [tensorflow inc
```
# python train_imagenet.py -h
usage: train_imagenet.py [-h] [--model_type MODEL_TYPE]
-                        [--weights_path WEIGHTS_PATH]
                         [--train_data_path TRAIN_DATA_PATH]
                         [--val_data_path VAL_DATA_PATH]
+                        [--weights_path WEIGHTS_PATH]
-                        [--batch_size BATCH_SIZE] [--optim_type OPTIM_TYPE]
+                        [--batch_size BATCH_SIZE]
+                        [--optimizer {adam,rmsprop,sgd}]
                         [--learning_rate LEARNING_RATE]
+                        [--decay_type {None,cosine,exponential,polynomial,piecewise_constant}]
+                        [--label_smoothing LABEL_SMOOTHING]
                         [--init_epoch INIT_EPOCH] [--total_epoch TOTAL_EPOCH]
                         [--gpu_num GPU_NUM] [--evaluate]
                         [--verify_with_image] [--dump_headless]
@@ -49,18 +49,23 @@ optional arguments:
  --model_type MODEL_TYPE
                        backbone model type: shufflenet/shufflenet_v2/nanonet/
                        darknet53/cspdarknet53, default=shufflenet_v2
- --weights_path WEIGHTS_PATH
-                       Pretrained model/weights file for fine tune
  --train_data_path TRAIN_DATA_PATH
                        path to Imagenet train data
  --val_data_path VAL_DATA_PATH
                        path to Imagenet validation dataset
+ --weights_path WEIGHTS_PATH
+                       Pretrained model/weights file for fine tune
  --batch_size BATCH_SIZE
                        batch size for train, default=128
- --optim_type OPTIM_TYPE
-                       optimizer type: sgd/rmsprop/adam, default=sgd
+ --optimizer {adam,rmsprop,sgd}
+                       optimizer for training (adam/rmsprop/sgd), default=sgd
  --learning_rate LEARNING_RATE
                        Initial learning rate, default=0.05
+ --decay_type {None,cosine,exponential,polynomial,piecewise_constant}
+                       Learning rate decay type, default=None
+ --label_smoothing LABEL_SMOOTHING
+                       Label smoothing factor (between 0 and 1) for
+                       classification loss, default=0
  --init_epoch INIT_EPOCH
                        Initial training epochs for fine tune training,
                        default=0
@@ -78,7 +86,7 @@ optional arguments:
For example, the following command will start training shufflenet_v2 with the ImageNet train/val data we prepared before:

```
-# python train_imagenet.py --model_type=shufflenet_v2 --train_data_path=data/ILSVRC2012_img_train/ --val_data_path=data/ILSVRC2012_img_val/ --batch_size=64
+# python train_imagenet.py --model_type=shufflenet_v2 --train_data_path=data/ILSVRC2012_img_train/ --val_data_path=data/ILSVRC2012_img_val/ --batch_size=128 --optimizer=adam --learning_rate=0.001 --decay_type=cosine --label_smoothing=0.1
```

Currently it supports shufflenet/shufflenet_v2/nanonet/darknet53/cspdarknet53, which are implemented under [backbones](https://github.com/david8862/keras-YOLOv3-model-set/tree/master/common/backbones) with fixed hyperparameters.
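
The new `--decay_type` flag moves learning-rate scheduling from the epoch-stepped `LearningRateScheduler` into the optimizer itself. The helper that builds it, `get_optimizer` from `common/model_utils`, is imported but not shown in this diff; as a minimal sketch under that assumption, the cosine option corresponds to the stock Keras schedule API:

```python
import tensorflow as tf

# Illustrative numbers only: one decay cycle spanning the whole run.
steps_per_epoch = 10009                 # e.g. 1281167 ImageNet images // batch_size 128
total_epoch, init_epoch = 200, 0
decay_steps = steps_per_epoch * (total_epoch - init_epoch)

# Anneal the LR from 0.05 toward 0 along a cosine curve, one step per batch.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.05, decay_steps=decay_steps)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
```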
59 changes: 35 additions & 24 deletions common/backbones/imagenet_training/train_imagenet.py
@@ -10,8 +10,8 @@
from multiprocessing import cpu_count

import tensorflow.keras.backend as K
-from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
+from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, LearningRateScheduler, TerminateOnNaN
#from tensorflow.keras.utils import multi_gpu_model
@@ -30,6 +30,11 @@
from yolo3.models.yolo3_darknet import DarkNet53
from yolo4.models.yolo4_darknet import CSPDarkNet53

+#from common.utils import optimize_tf_gpu
+from common.model_utils import get_optimizer
+from common.callbacks import CheckpointCleanCallBack
+
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

@@ -91,18 +96,6 @@ def get_model(model_type, include_top=True):
    return model, input_shape[:2]


-def get_optimizer(optim_type, learning_rate):
-    if optim_type == 'sgd':
-        optimizer = SGD(lr=learning_rate, decay=5e-4, momentum=0.9)
-    elif optim_type == 'rmsprop':
-        optimizer = RMSprop(lr=learning_rate)
-    elif optim_type == 'adam':
-        optimizer = Adam(lr=learning_rate, decay=5e-4)
-    else:
-        raise ValueError('Unsupported optimizer type')
-    return optimizer
-
-
def train(args, model, input_shape, strategy):
    log_dir = os.path.join('logs', '000')

@@ -120,7 +113,7 @@ def train(args, model, input_shape, strategy):
    lr_scheduler = LearningRateScheduler(lambda epoch: learn_rates[epoch // 30])
    checkpoint_clean = CheckpointCleanCallBack(log_dir, max_val_keep=5)

-    callbacks=[logging, checkpoint, lr_scheduler, terminate_on_nan, checkpoint_clean])
+    callbacks=[logging, checkpoint, lr_scheduler, terminate_on_nan, checkpoint_clean]

    # data generator
    train_datagen = ImageDataGenerator(preprocessing_function=preprocess,
@@ -175,20 +168,27 @@ def train(args, model, input_shape, strategy):
                                                  interpolation='nearest')

    # get optimizer
-    optimizer = get_optimizer(args.optim_type, args.learning_rate)
+    if args.decay_type:
+        callbacks.remove(lr_scheduler)
+    steps_per_epoch = max(1, train_generator.samples//args.batch_size)
+    decay_steps = steps_per_epoch * (args.total_epoch - args.init_epoch)
+    optimizer = get_optimizer(args.optimizer, args.learning_rate, average_type=None, decay_type=args.decay_type, decay_steps=decay_steps)

+    # get loss
+    losses = CategoricalCrossentropy(label_smoothing=args.label_smoothing)
+
    # model compile
    if strategy:
        with strategy.scope():
            model.compile(
                optimizer=optimizer,
                metrics=['accuracy', 'top_k_categorical_accuracy'],
-                loss='categorical_crossentropy')
+                loss=losses)
    else:
        model.compile(
            optimizer=optimizer,
            metrics=['accuracy', 'top_k_categorical_accuracy'],
-            loss='categorical_crossentropy')
+            loss=losses)

    # start training
    print('Train on {} samples, val on {} samples, with batch size {}.'.format(train_generator.samples, test_generator.samples, args.batch_size))
@@ -218,7 +218,7 @@ def evaluate_model(args, model, input_shape):
                                                  batch_size=args.batch_size)

    # get optimizer
-    optimizer = get_optimizer(args.optim_type, args.learning_rate)
+    optimizer = get_optimizer(args.optimizer, args.learning_rate, average_type=None, decay_type=None)

    # start evaluate
    model.compile(
@@ -309,26 +309,37 @@ def main(args):

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
+    # Model definition options
    parser.add_argument('--model_type', type=str, required=False, default='shufflenet_v2',
        help='backbone model type: shufflenet/shufflenet_v2/nanonet/darknet53/cspdarknet53, default=%(default)s')
-    parser.add_argument('--train_data_path', type=str,# required=True,
-        help='path to Imagenet train data')
-    parser.add_argument('--val_data_path', type=str,# required=True,
-        help='path to Imagenet validation dataset')
    parser.add_argument('--weights_path', type=str, required=False, default=None,
        help = "Pretrained model/weights file for fine tune")
+
+    # Data options
+    parser.add_argument('--train_data_path', type=str, #required=True,
+        help='path to Imagenet train data')
+    parser.add_argument('--val_data_path', type=str, #required=True,
+        help='path to Imagenet validation dataset')
+
+    # Training options
    parser.add_argument('--batch_size', type=int, required=False, default=128,
        help = "batch size for train, default=%(default)s")
-    parser.add_argument('--optim_type', type=str, required=False, default='sgd', choices=['sgd', 'rmsprop', 'adam'],
-        help='optimizer type: sgd/rmsprop/adam, default=%(default)s')
+    parser.add_argument('--optimizer', type=str, required=False, default='sgd', choices=['adam', 'rmsprop', 'sgd'],
+        help = "optimizer for training (adam/rmsprop/sgd), default=%(default)s")
    parser.add_argument('--learning_rate', type=float, required=False, default=.05,
        help = "Initial learning rate, default=%(default)s")
+    parser.add_argument('--decay_type', type=str, required=False, default=None, choices=[None, 'cosine', 'exponential', 'polynomial', 'piecewise_constant'],
+        help = "Learning rate decay type, default=%(default)s")
+    parser.add_argument('--label_smoothing', type=float, required=False, default=0,
+        help = "Label smoothing factor (between 0 and 1) for classification loss, default=%(default)s")
    parser.add_argument('--init_epoch', type=int, required=False, default=0,
        help = "Initial training epochs for fine tune training, default=%(default)s")
    parser.add_argument('--total_epoch', type=int, required=False, default=200,
        help = "Total training epochs, default=%(default)s")
    parser.add_argument('--gpu_num', type=int, required=False, default=1,
        help='Number of GPU to use, default=%(default)s')
+
+    # Evaluation options
    parser.add_argument('--evaluate', default=False, action="store_true",
        help='Evaluate a trained model with validation dataset')
    parser.add_argument('--verify_with_image', default=False, action="store_true",
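
A note on the new loss wiring in `train()`: Keras `CategoricalCrossentropy(label_smoothing=...)` blends the one-hot targets as y*(1-ε) + ε/num_classes before computing the cross-entropy, which is exactly what the `--label_smoothing` flag controls. A small self-contained check with made-up values:

```python
import numpy as np
import tensorflow as tf

y_true = np.array([[0., 1., 0., 0.]], dtype=np.float32)      # 4 classes, one-hot
y_pred = np.array([[0.05, 0.85, 0.05, 0.05]], dtype=np.float32)

smoothed = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
# Smoothing the targets by hand gives the same loss value:
manual = tf.keras.losses.CategoricalCrossentropy()(y_true * 0.9 + 0.1 / 4, y_pred)
print(float(smoothed(y_true, y_pred)), float(manual))
```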
87 changes: 69 additions & 18 deletions common/backbones/mobilevit.py
@@ -4,20 +4,23 @@
# A tf.keras implementation of MobileViT,
# ported from https://keras.io/examples/vision/mobilevit/
#
-# Reference Paper
+# Reference
# [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178)
+# https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
#
import os, sys
import warnings
+import math

from keras_applications.imagenet_utils import _obtain_input_shape
from keras_applications.imagenet_utils import preprocess_input as _preprocess_input
from tensorflow.keras.utils import get_source_inputs, get_file
-from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, GlobalAveragePooling2D, GlobalMaxPooling2D, Dropout, ZeroPadding2D
+from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, GlobalAveragePooling2D, GlobalMaxPooling2D, Dropout, ZeroPadding2D, Lambda
from tensorflow.keras.layers import Input, BatchNormalization, Add, Reshape, LayerNormalization, MultiHeadAttention, Concatenate, Activation
from tensorflow.keras.activations import swish
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
+import tensorflow as tf


def preprocess_input(x):
@@ -166,7 +169,50 @@ def transformer_block(x, projection_dim, num_heads, dropout, prefix):
    return x


+def img_resize(x, size, mode='bilinear'):
+    if mode == 'bilinear':
+        return tf.image.resize(x, size=size, method='bilinear')
+    elif mode == 'nearest':
+        return tf.image.resize(x, size=size, method='nearest')
+    elif mode == 'bicubic':
+        return tf.image.resize(x, size=size, method='bicubic')
+    elif mode == 'area':
+        return tf.image.resize(x, size=size, method='area')
+    elif mode == 'gaussian':
+        return tf.image.resize(x, size=size, method='gaussian')
+    else:
+        raise ValueError('invalid resize type {}'.format(mode))
+
+
+def unfolding(x, patch_h, patch_w, prefix):
+    batch_size, orig_h, orig_w, in_channels = x.shape
+
+    # get tensor width & height aligned with patch size
+    new_h = int(math.ceil(orig_h / patch_h) * patch_h)
+    new_w = int(math.ceil(orig_w / patch_w) * patch_w)
+
+    if new_h != orig_h or new_w != orig_w:
+        # resize feature tensor for unfolding
+        x = Lambda(img_resize,
+                   arguments={'size': (new_h, new_w), 'mode': 'bilinear'},
+                   name=prefix+'unfold_resize')(x)
+
+    # number of patches along new width and height
+    num_patch_w = new_w // patch_w # n_w
+    num_patch_h = new_h // patch_h # n_h
+    num_patches = num_patch_h * num_patch_w # N
+    patch_size = patch_h * patch_w # P
+
+    # [new_h, new_w, C] --> [P, N, C]
+    x = Reshape((patch_size, num_patches, -1),
+                name=prefix+'unfold')(x)
+
+    return x, new_h, new_w
+
+
def mobilevit_block(x, num_blocks, num_heads, projection_dim, strides, dropout, block_id):
+    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
+    in_channels = x.shape[channel_axis]
    prefix = 'mvit_block_{}_'.format(block_id)

    # Local projection with convolutions.
@@ -179,11 +225,9 @@ def mobilevit_block(x, num_blocks, num_heads, projection_dim, strides, dropout,
                               strides=strides,
                               name=prefix+'conv2')

-    patch_size = 4 # 2x2, for the Transformer blocks.
    # Unfold into patches and then pass through Transformers.
-    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
-    non_overlapping_patches = Reshape((patch_size, num_patches, projection_dim),
-                                      name=prefix+'unfold')(local_features)
+    patch_h, patch_w = 2, 2 # 2x2, for the Transformer blocks.
+    non_overlapping_patches, new_h, new_w = unfolding(local_features, patch_h, patch_w, prefix)

    # Transformer blocks
    global_features = non_overlapping_patches
@@ -196,12 +240,19 @@ def mobilevit_block(x, num_blocks, num_heads, projection_dim, strides, dropout,
                                  prefix=name)

    # Fold into conv-like feature-maps.
-    folded_feature_map = Reshape((*local_features.shape[1:-1], projection_dim),
+    folded_feature_map = Reshape((new_h, new_w, projection_dim),
                                 name=prefix+'fold')(global_features)

+    # resize back to local feature shape
+    orig_h, orig_w = local_features.shape[1:-1]
+    if new_h != orig_h or new_w != orig_w:
+        folded_feature_map = Lambda(img_resize,
+                                    arguments={'size': (orig_h, orig_w), 'mode': 'bilinear'},
+                                    name=prefix+'fold_resize')(folded_feature_map)
+
    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(folded_feature_map,
-                                   filters=x.shape[-1],
+                                   filters=in_channels,
                                    kernel_size=1,
                                    strides=strides,
                                    name=prefix+'conv3')
@@ -211,7 +262,8 @@ def mobilevit_block(x, num_blocks, num_heads, projection_dim, strides, dropout,

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(local_global_features,
-                                      filters=projection_dim,
+                                      #filters=projection_dim,
+                                      filters=in_channels,
                                       strides=strides,
                                       name=prefix+'conv4')
    return local_global_features
@@ -248,7 +300,7 @@ def MobileViT(channels,
    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=256,
-                                      min_size=64,
+                                      min_size=32,
                                      data_format=K.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)
@@ -260,8 +312,8 @@
    rows = input_shape[row_axis]
    cols = input_shape[col_axis]

-    if rows and cols and (rows < 64 or cols < 64):
-        raise ValueError('Input size must be at least 64x64; got `input_shape=' +
+    if rows and cols and (rows < 32 or cols < 32):
+        raise ValueError('Input size must be at least 32x32; got `input_shape=' +
                         str(input_shape) + '`')

    if input_tensor is None:
@@ -274,12 +326,11 @@
                             str(input_tensor.shape) + '`')
        img_input = input_tensor

-    assert (rows%64 == 0 and cols%64 == 0), 'input shape should be multiples of 64'
-
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

-    # Transformer block number for each MobileViT block
+    # Transformer block_number/head_number for each MobileViT block
    mvit_blocks = [2, 4, 3]
+    num_heads = 1

    # Initial stem-conv -> MV2 block.
    x = conv_block(img_input, filters=channels[0], name='stem_conv')
@@ -298,17 +349,17 @@
    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=channels[3] * expansion, output_channels=channels[4], strides=2, block_id=4)
-    x = mobilevit_block(x, num_blocks=mvit_blocks[0], num_heads=4, projection_dim=dims[0], strides=1, dropout=0.1, block_id=0)
+    x = mobilevit_block(x, num_blocks=mvit_blocks[0], num_heads=num_heads, projection_dim=dims[0], strides=1, dropout=0.1, block_id=0)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=channels[5] * expansion, output_channels=channels[5], strides=2, block_id=5)
-    x = mobilevit_block(x, num_blocks=mvit_blocks[1], num_heads=4, projection_dim=dims[1], strides=1, dropout=0.1, block_id=1)
+    x = mobilevit_block(x, num_blocks=mvit_blocks[1], num_heads=num_heads, projection_dim=dims[1], strides=1, dropout=0.1, block_id=1)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=channels[6] * expansion, output_channels=channels[6], strides=2, block_id=6)
-    x = mobilevit_block(x, num_blocks=mvit_blocks[2], num_heads=4, projection_dim=dims[2], strides=1, dropout=0.1, block_id=2)
+    x = mobilevit_block(x, num_blocks=mvit_blocks[2], num_heads=num_heads, projection_dim=dims[2], strides=1, dropout=0.1, block_id=2)

    x = conv_block(x, filters=channels[7], kernel_size=1, strides=1, name='1x1_conv')

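
The `unfolding`/fold changes above are shape bookkeeping: a feature map whose height or width does not tile evenly with the 2x2 patches is bilinearly resized up to the nearest multiple, reshaped to `[P, N, C]` for the Transformer blocks, then reshaped and resized back, which is what lifts the old multiples-of-64 input restriction. A toy sketch of that arithmetic (hypothetical sizes, batch of 1):

```python
import tensorflow as tf

x = tf.random.normal([1, 7, 9, 16])       # H=7, W=9 don't tile with 2x2 patches
patch_h = patch_w = 2
new_h = -(-7 // patch_h) * patch_h        # ceil(7/2)*2 = 8
new_w = -(-9 // patch_w) * patch_w        # ceil(9/2)*2 = 10
x = tf.image.resize(x, (new_h, new_w), method='bilinear')

num_patches = (new_h // patch_h) * (new_w // patch_w)             # N = 20
patches = tf.reshape(x, [1, patch_h * patch_w, num_patches, 16])  # [B, P, N, C]
folded = tf.reshape(patches, [1, new_h, new_w, 16])               # fold back
restored = tf.image.resize(folded, (7, 9), method='bilinear')     # original HxW
print(patches.shape, folded.shape, restored.shape)
```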
