Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MobileNetV3 architecture for Segmentation #3276

Merged
merged 12 commits into from
Jan 27, 2021

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Jan 22, 2021

Adding MobileNetV3 models for Semantic Segmentation (resolution 520):

Lite R-ASPP with Dilated MobileNetV3 Large Backbone

Heavily optimized for speed. Good for actual mobile usage.

Weight checkpoint:

PR3276/3rd_training/35354083/model_28.pth

Validate:

python -m torch.distributed.launch --nproc_per_node=2 --use_env train.py --dataset coco\
   --model lraspp_mobilenet_v3_large --test-only --pretrained

Accuracy metrics:

0: global correct: 91.2
0: average row correct: ['94.5', '84.3', '69.5', '72.8', '57.7', '42.0', '77.0', '57.0', '90.4', '36.1', '76.0', '60.8', '81.4', '78.9', '81.0', '87.6', '51.3', '83.9', '62.2', '84.2', '56.1']
0: IoU: ['90.2', '69.2', '57.7', '58.5', '47.8', '35.7', '69.5', '47.1', '79.1', '29.6', '62.6', '34.2', '65.5', '63.4', '70.0', '76.8', '30.1', '61.9', '46.8', '70.6', '49.1']
0: mean IoU: 57.9

Speed Benchmark: 0.3278 sec per image on CPU

DeepLabV3 with Dilated MobileNetV3 Large Backbone

Offers good balance between speed and accuracy, significantly faster than the FCN model with a resnet50 backbone without sacrificing too much accuracy.

Weight checkpoint:

PR3276/3rd_training/35354080/model_28.pth

Validate:

python -m torch.distributed.launch --nproc_per_node=2 --use_env train.py --dataset coco\
   --model deeplabv3_mobilenet_v3_large --test-only --pretrained

Accuracy metrics:

0: global correct: 91.2
0: average row correct: ['93.7', '84.9', '73.6', '74.6', '63.6', '50.6', '80.7', '65.1', '91.3', '42.2', '80.4', '70.6', '82.4', '81.8', '83.7', '88.5', '52.6', '87.8', '65.9', '88.3', '63.3']
0: IoU: ['90.1', '69.7', '58.2', '61.3', '49.7', '37.6', '72.7', '52.7', '79.1', '32.2', '64.6', '36.2', '66.7', '67.4', '70.1', '77.3', '33.1', '67.8', '51.1', '73.3', '54.4']
0: mean IoU: 60.3

Speed Benchmark: 0.5869 sec per image on CPU

@datumbox datumbox changed the title [WIP] Add MobileNetV3 architecture for Segmentation Add MobileNetV3 architecture for Segmentation Jan 27, 2021
@datumbox datumbox requested a review from fmassa January 27, 2021 12:46
Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

I only have a couple of minor (non-blocking) comments. The only thing I would really like to see fixed before merge is to have the correct python -m torch.distributed.launch ... commands for reproducibility.

references/segmentation/README.md Outdated Show resolved Hide resolved
torchvision/models/segmentation/segmentation.py Outdated Show resolved Hide resolved
torchvision/models/segmentation/segmentation.py Outdated Show resolved Hide resolved
@@ -82,7 +84,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod

self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self.is_strided = cnf.stride > 1
self._is_cn = cnf.stride > 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity, what does cn mean in here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's from the C0,C1...C5,Cn names used in Object Detection. I use this feature internally to find out where the downsampling was supposed to happen but it's not always done with strides so I had to rename it. If you have any better name for it, happy to change it. I could not think of any...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the explanation. Given that this is private I'm fine with this name

"""
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this feature used in any of the models? Otherwise we can just remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a unique implementation detail from the paper on MobileNetV3 models and it's supposed to produce a further speed optimization on object detection and segmentation. In our training scripts we don't use it because we do transfer learning from ImageNet but if someone really wants to train it from scratch and go smaller I provide a way to do it.

On current master this is public (see reduced_tail param) but here I decide to hide before the release and make it an internal implementation detail for future models. Not quite convinced we will use it but want to provide an implementation very close to the paper.

Personally I would prefer to keep it hidden for now and decide later whether we want this gone. Let me know.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'm ok keeping this private for now and maybe removing it from the future.

@datumbox datumbox merged commit e2db2ed into pytorch:master Jan 27, 2021
@datumbox datumbox deleted the mobilenetv3_segmentation branch January 27, 2021 14:09
facebook-github-bot pushed a commit that referenced this pull request Feb 1, 2021
Summary:
* Making _segm_resnet() generic and reusable.

* Adding fcn and deeplabv3 directly on mobilenetv3 backbone.

* Adding tests for segmentation models.

* Rename is_strided with _is_cn.

* Add dilation support on MobileNetV3 for Segmentation.

* Add Lite R-ASPP with MobileNetV3 backbone.

* Add pretrained model weights.

* Removing model fcn_mobilenet_v3_large.

* Adding docs and imports.

* Fixing typo and readme.

Reviewed By: datumbox

Differential Revision: D26156380

fbshipit-source-id: e62528b52728804a40da79c1311562a7f1c2afbd
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants