Skip to content

Commit

Permalink
add mobilevit imagenet pretrained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
david8862 committed Mar 9, 2022
1 parent 56bcc2e commit adde66a
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 44 deletions.
18 changes: 12 additions & 6 deletions common/backbones/imagenet_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ The validation preprocess script and synset labels file are from [tensorflow inc
1. [train_imagenet.py](https://github.com/david8862/keras-YOLOv3-model-set/blob/master/common/backbones/imagenet_training/train_imagenet.py)
```
# python train_imagenet.py -h
usage: train_imagenet.py [-h] [--model_type MODEL_TYPE]
usage: train_imagenet.py [-h] --model_type
{shufflenet,shufflenet_v2,nanonet,darknet53,cspdarknet53,mobilevit_s,mobilevit_xs,mobilevit_xxs}
[--weights_path WEIGHTS_PATH]
[--train_data_path TRAIN_DATA_PATH]
[--val_data_path VAL_DATA_PATH]
Expand All @@ -49,9 +50,8 @@ usage: train_imagenet.py [-h] [--model_type MODEL_TYPE]
optional arguments:
-h, --help show this help message and exit
--model_type MODEL_TYPE
backbone model type: shufflenet/shufflenet_v2/nanonet/
darknet53/cspdarknet53, default=shufflenet_v2
--model_type {shufflenet,shufflenet_v2,nanonet,darknet53,cspdarknet53,mobilevit_s,mobilevit_xs,mobilevit_xxs}
backbone model type
--weights_path WEIGHTS_PATH
Pretrained model/weights file for fine tune
--train_data_path TRAIN_DATA_PATH
Expand All @@ -63,7 +63,7 @@ optional arguments:
--optimizer {adam,rmsprop,sgd}
optimizer for training (adam/rmsprop/sgd), default=sgd
--learning_rate LEARNING_RATE
Initial learning rate, default=0.05
Initial learning rate, default=0.01
--decay_type {None,cosine,exponential,polynomial,piecewise_constant}
Learning rate decay type, default=None
--label_smoothing LABEL_SMOOTHING
Expand All @@ -89,7 +89,7 @@ For example, following cmd will start training shufflenet_v2 with the Imagenet t
# python train_imagenet.py --model_type=shufflenet_v2 --train_data_path=data/ILSVRC2012_img_train/ --val_data_path=data/ILSVRC2012_img_val/ --batch_size=128 --optimizer=sgd --learning_rate=0.01 --decay_type=cosine --label_smoothing=0.1
```

Currently it support shufflenet/shufflenet_v2/nanonet/darknet53/cspdarknet53 which is implement under [backbones](https://github.com/david8862/keras-YOLOv3-model-set/tree/master/common/backbones) with fixed hyperparam.
Currently it support shufflenet/shufflenet_v2/nanonet/darknet53/cspdarknet53/mobilevit_s/mobilevit_xs/mobilevit_xxs which is implement under [backbones](https://github.com/david8862/keras-YOLOv3-model-set/tree/master/common/backbones) with fixed hyperparam.

Checkpoints during training could be found at logs/. Choose a best one as result

Expand Down Expand Up @@ -145,4 +145,10 @@ MultiGPU usage: use `--gpu_num N` to use N GPUs. It is passed to the [Keras mult
journal = {arXiv},
year={2018}
}
@article{MobileViT,
title={MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
author={Sachin Mehta, Mohammad Rastegari},
journal = {arXiv},
year={2021}
}
```
4 changes: 2 additions & 2 deletions common/backbones/imagenet_training/train_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Model definition options
parser.add_argument('--model_type', type=str, required=False, default='shufflenet_v2',
help='backbone model type: shufflenet/shufflenet_v2/nanonet/darknet53/cspdarknet53, default=%(default)s')
parser.add_argument('--model_type', type=str, required=True, choices=['shufflenet', 'shufflenet_v2', 'nanonet', 'darknet53', 'cspdarknet53', 'mobilevit_s', 'mobilevit_xs', 'mobilevit_xxs'],
help='backbone model type')
parser.add_argument('--weights_path', type=str, required=False, default=None,
help = "Pretrained model/weights file for fine tune")

Expand Down
38 changes: 22 additions & 16 deletions common/backbones/mobilevit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
import tensorflow as tf


BASE_WEIGHT_PATH = (
'https://github.com/david8862/tf-keras-image-classifier/'
'releases/download/v1.0.0/')


def preprocess_input(x):
"""
"mode" option description in preprocess_input
Expand Down Expand Up @@ -447,20 +452,20 @@ def MobileViT(channels,

# Load weights.
if weights == 'imagenet':
raise ValueError('No valid ImageNet pretrained weights now.')
#if include_top:
#model_name = ('mobilenet_v2_weights_tf_dim_ordering_tf_kernels_' +
#str(alpha) + '_' + str(rows) + '.h5')
#weight_path = BASE_WEIGHT_PATH + model_name
#weights_path = get_file(
#model_name, weight_path, cache_subdir='models')
#else:
#model_name = ('mobilenet_v2_weights_tf_dim_ordering_tf_kernels_' +
#str(alpha) + '_' + str(rows) + '_no_top' + '.h5')
#weight_path = BASE_WEIGHT_PATH + model_name
#weights_path = get_file(
#model_name, weight_path, cache_subdir='models')
#model.load_weights(weights_path)
if model_type == 's':
raise ValueError('No valid ImageNet pretrained weights for mobilevit_s now.')

if include_top:
model_name = ('mobilevit_' + model_type + '_weights_tf_dim_ordering_tf_kernels_256.h5')
weight_path = BASE_WEIGHT_PATH + model_name
weights_path = get_file(
model_name, weight_path, cache_subdir='models')
else:
model_name = ('mobilevit_' + model_type + '_weights_tf_dim_ordering_tf_kernels_256_no_top.h5')
weight_path = BASE_WEIGHT_PATH + model_name
weights_path = get_file(
model_name, weight_path, cache_subdir='models')
model.load_weights(weights_path)
elif weights is not None:
model.load_weights(weights)

Expand Down Expand Up @@ -514,9 +519,10 @@ def MobileViT_XXS(input_shape=None,

if __name__ == '__main__':
input_tensor = Input(shape=(None, None, 3), name='image_input')
#model = MobileViT_XXS(include_top=True, input_tensor=input_tensor, weights=None)
model = MobileViT_XXS(include_top=True, input_shape=(256, 256, 3), weights=None)
#model = MobileViT_XXS(include_top=True, input_tensor=input_tensor, weights='imagenet')
model = MobileViT_XXS(include_top=True, input_shape=(256, 256, 3), weights='imagenet')
model.summary()
K.set_learning_phase(0)

import numpy as np
from tensorflow.keras.applications.resnet50 import decode_predictions
Expand Down
71 changes: 54 additions & 17 deletions tracking/cpp_inference/yoloSort/kalman_demo/kalman_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// Reference doc:
// https://blog.csdn.net/GDFSG/article/details/50904811
// --------------------------------------------------------------------
#include <random>
#include <stdio.h>
#include "opencv2/video/tracking.hpp"
#include "opencv2/highgui/highgui.hpp"

Expand All @@ -26,12 +28,37 @@ void mouseEvent(int event, int x, int y, int flags, void *param)
}


int randomIntGenerate(const int low, const int high)
{
static std::random_device rd;
static std::default_random_engine eng(rd());
//static std::mt19937 eng(rd());

std::uniform_int_distribution<int> dis(low, high);

return dis(eng);
}


float randomFloatGenerate(const float low, const float high)
{
static std::random_device rd;
static std::default_random_engine eng(rd());
//static std::mt19937 eng(rd());

std::uniform_real_distribution<float> dis(low, high);
//std::normal_distribution<float> dis((high+low)/2, (high-low)/2);

return dis(eng);
}

void kalmanTest(void)
{
// define state value (x, y, x', y') number, measurement value (x, y) number,
// here x' is speed of x
int stateNum = 4;
int measureNum = 2;
// here x', y' is speed of x, y
const int stateNum = 4;
const int measureNum = 2;
const float dt = 0.1;

// create kalman filter object, 0 is control value number
cv::KalmanFilter kalman = cv::KalmanFilter(stateNum, measureNum, 0);
Expand All @@ -41,23 +68,23 @@ void kalmanTest(void)

// define state transition matrix, here we assume
// uniform linear motion for both x and y:
// x* = x + t * x' (x* is x for next step; t = 1)
// y* = y + t * y'
// x* = x + dt * x' (x* is x for next step; dt = 0.1s)
// y* = y + dt * y'
// x*' = x'
// y*' = y'
//
// in form of matrix:
//
// / x* \ = / 1 0 1 0 \ / x \
// | y* | = | 0 1 0 1 | | y |
// | x*'| = | 0 0 1 0 | | x'|
// \ y*'/ = \ 0 0 0 1 / \ y'/
// / x* \ = / 1 0 dt 0 \ / x \
// | y* | = | 0 1 0 dt | | y |
// | x*'| = | 0 0 1 0 | | x'|
// \ y*'/ = \ 0 0 0 1 / \ y'/
//
kalman.transitionMatrix = (cv::Mat_<float>(stateNum, stateNum) <<
1, 0, 1, 0,
0, 1, 0, 1,
0, 0, 1, 0,
0, 0, 0, 1);
1, 0, dt, 0,
0, 1, 0, dt,
0, 0, 1, 0,
0, 0, 0, 1);

// initialize measurement matrix with diag(1)
cv::setIdentity(kalman.measurementMatrix, cv::Scalar::all(1));
Expand All @@ -83,10 +110,13 @@ void kalmanTest(void)
cv::Mat prediction = kalman.predict();
cv::Point predictPt = cv::Point(prediction.at<float>(0, 0), prediction.at<float>(1, 0));

// pick measurement value from mouse position
// pick measurement value from mouse position, here we add
// random number to simulate measurement noise
cv::Point statePt = mousePosition;
measurement.at<float>(0, 0) = statePt.x;
measurement.at<float>(1, 0) = statePt.y;
//measurement.at<float>(0, 0) = statePt.x + randomIntGenerate(-5, 5);
//measurement.at<float>(1, 0) = statePt.y + randomIntGenerate(-5, 5);
measurement.at<float>(0, 0) = statePt.x + randomFloatGenerate(-5.0, 5.0);
measurement.at<float>(1, 0) = statePt.y + randomFloatGenerate(-5.0, 5.0);

// update measurement value
kalman.correct(measurement);
Expand All @@ -96,8 +126,15 @@ void kalmanTest(void)
cv::circle(img, predictPt, 8, CV_RGB(0, 255, 0), -1); // predicted point as green
cv::circle(img, statePt, 8, CV_RGB(255, 0, 0), -1); // current position as red

// show predict & current point coordinate
char buf[256];
sprintf(buf, "predicted position:(%3d,%3d)", predictPt.x, predictPt.y);
cv::putText(img, buf, cv::Point(10,30), CV_FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(255,0,0), 1, 8);
sprintf(buf, "current position:(%3d,%3d)", statePt.x, statePt.y);
cv::putText(img, buf, cv::Point(10,60), CV_FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(255,0,0), 1, 8);

cv::imshow("Kalman", img);
char code = (char)cv::waitKey(100);
char code = (char)cv::waitKey(dt*1000);
if (code == 27 || code == 'q' || code == 'Q')
break;
}
Expand Down
2 changes: 1 addition & 1 deletion yolo2/loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding=utf-8 -*-
#!/usr/bin/python3
# -*- coding=utf-8 -*-

import math
import tensorflow as tf
Expand Down
2 changes: 1 addition & 1 deletion yolo3/loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding=utf-8 -*-
#!/usr/bin/python3
# -*- coding=utf-8 -*-

import math
import tensorflow as tf
Expand Down
2 changes: 1 addition & 1 deletion yolo5/loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding=utf-8 -*-
#!/usr/bin/python3
# -*- coding=utf-8 -*-

import math
import tensorflow as tf
Expand Down

0 comments on commit adde66a

Please sign in to comment.