Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Feature] Add ViT-Adapter Model #2762

Merged
merged 17 commits into from
Mar 17, 2023
Prev Previous commit
Next Next commit
refine for merge 2
  • Loading branch information
juncaipeng committed Nov 25, 2022
commit cc7aa0aec3eb91815dfc4d6b858dc130973dc461
10 changes: 4 additions & 6 deletions configs/vit_adapter/README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# Semantic Flow for Fast and Accurate Scene Parsing
# Vision Transformer Adapter for Dense Predictions

## Reference

> Xiangtai Li, Ansheng You, Zhen Zhu, Houlong Zhao, Maoke Yang, Kuiyuan Yang, Shaohua Tan, Yunhai Tong:
Semantic Flow for Fast and Accurate Scene Parsing. ECCV (1) 2020: 775-793 .
> Chen, Zhe, Yuchen Duan, Wenhai Wang, Junjun He, Tong Lu, Jifeng Dai, and Yu Qiao. "Vision Transformer Adapter for Dense Predictions." arXiv preprint arXiv:2205.08534 (2022).

## Performance

### Cityscapes
### ADE20K

| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
|-|-|-|-|-|-|-|-|
|SFNet|ResNet18_OS8|1024x1024|80000|78.72%|79.11%|79.28%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/sfnet_resnet18_os8_cityscapes_1024x1024_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/sfnet_resnet18_os8_cityscapes_1024x1024_80k/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0d790ad96282048b136342fcebb08d14)|
|SFNet|ResNet50_OS8|1024x1024|80000|81.49%|81.63%|81.85%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/sfnet_resnet50_os8_cityscapes_1024x1024_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/sfnet_resnet50_os8_cityscapes_1024x1024_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=d458349ec63ea8ccd6fae84afa8ea981)|
|UPerNetViTAdapter|ViT-Adapter-Tiny|512x512|160000|%|%|%|[model]() \| [log]() \| [vdl]()|
2 changes: 0 additions & 2 deletions paddleseg/core/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def train(model,
reader_cost_averager.record(time.time() - batch_start)
images = data['img']
labels = data['label'].astype('int64')

edges = None
if 'edge' in data.keys():
edges = data['edge'].astype('int64')
Expand Down Expand Up @@ -212,7 +211,6 @@ def train(model,
losses=losses)
loss = sum(loss_list)
loss.backward()

# if the optimizer is ReduceOnPlateau, the loss is the one which has been pass into step.
if isinstance(optimizer, paddle.optimizer.lr.ReduceOnPlateau):
optimizer.step(loss)
Expand Down
82 changes: 28 additions & 54 deletions paddleseg/models/upernet_vit_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,24 @@
@manager.MODELS.add_component
class UPerNetViTAdapter(nn.Layer):
"""
The UPerNet implementation based on PaddlePaddle.
The UPerNetViTAdapter implementation based on PaddlePaddle.

The original article refers to
Tete Xiao, et, al. "Unified Perceptual Parsing for Scene Understanding"
(https://arxiv.org/abs/1807.10221).
Chen, Zhe, Yuchen Duan, Wenhai Wang, Junjun He, Tong Lu, Jifeng Dai, and Yu Qiao.
"Vision Transformer Adapter for Dense Predictions."
(https://arxiv.org/abs/2205.08534).

Args:
num_classes (int): The unique number of target classes.
backbone (Paddle.nn.Layer): Backbone network, currently support Resnet50/101.
backbone_indices (tuple): Four values in the tuple indicate the indices of output of backbone.
channels (int): The channels of inter layers. Default: 512.
aux_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: False.
backbone (nn.Layer): The backbone network.
        backbone_indices (tuple | list): The indices of the backbone output feature maps to use.
channels (int, optional): The channels of inter layers in upernet head. Default: 512.
pool_scales (list, optional): The scales in PPM. Default: [1, 2, 3, 6].
dropout_ratio (float, optional): The dropout ratio for upernet head. Default: 0.1.
aux_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
aux_channels (int, optional): The channels of inter layers in auxiliary head. Default: 256.
align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
dropout_ratio (float): Dropout ratio for upernet head. Default: 0.1.
pretrained (str, optional): The path or url of pretrained model. Default: None.
"""

Expand Down Expand Up @@ -72,6 +75,10 @@ def __init__(self,
self.pretrained = pretrained
self.init_weight()

def init_weight(self):
if self.pretrained is not None:
utils.load_entire_model(self, self.pretrained)

def forward(self, x):
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
Expand All @@ -85,10 +92,6 @@ def forward(self, x):
]
return logit_list

def init_weight(self):
if self.pretrained is not None:
utils.load_entire_model(self, self.pretrained)


class ConvBNReLU(nn.Layer):
def __init__(self,
Expand Down Expand Up @@ -118,12 +121,9 @@ class PPM(nn.Layer):
"""Pooling Pyramid Module used in PSPNet.

Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
pool_scales (tuple | list): Pooling scales used in PPM.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
        channels (int): Output channels after modules, before conv_seg.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
Expand All @@ -145,7 +145,6 @@ def __init__(self, pool_scales, in_channels, channels, align_corners):
kernel_size=1)))

def forward(self, x):
"""Forward function."""
ppm_outs = []
for ppm in self.stages:
ppm_out = ppm(x)
Expand All @@ -159,16 +158,20 @@ def forward(self, x):


class UPerNetHead(nn.Layer):
"""Unified Perceptual Parsing for Scene Understanding.

This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.

"""
    This head is the implementation of "Unified Perceptual Parsing for Scene Understanding".
    It is heavily based on https://github.com/czczup/ViT-Adapter.

Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
num_classes (int): The unique number of target classes.
in_channels (list[int]): The channels of input features.
channels (int, optional): The channels of inter layers in upernet head. Default: 512.
pool_scales (list, optional): The scales in PPM. Default: [1, 2, 3, 6].
dropout_ratio (float, optional): The dropout ratio for upernet head. Default: 0.1.
aux_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
aux_channels (int, optional): The channels of inter layers in auxiliary head. Default: 256.
align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
"""

def __init__(self,
Expand Down Expand Up @@ -204,7 +207,6 @@ def __init__(self,

self.fpn_bottleneck = ConvBNReLU(
len(in_channels) * channels, channels, 3, padding=1)

if dropout_ratio > 0:
self.dropout = nn.Dropout2D(dropout_ratio)
else:
Expand All @@ -219,7 +221,6 @@ def __init__(self,
aux_channels, num_classes, kernel_size=1)

def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
Expand All @@ -228,25 +229,13 @@ def psp_forward(self, inputs):
return output

def forward(self, inputs):
"""Forward function."""
debug = False
if debug:
print('-------head 1----')
for x in inputs:
print(x.shape, x.numpy().mean())

# build laterals
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))

if debug:
print('-------head 2----')
for x in laterals:
print(x.shape, x.numpy().mean())

# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
Expand All @@ -264,11 +253,6 @@ def forward(self, inputs):
]
fpn_outs.append(laterals[-1]) # append psp feature

if debug:
print('-------head 3----')
for x in fpn_outs:
print(x.shape, x.numpy().mean())

for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = F.interpolate(
fpn_outs[i],
Expand All @@ -278,10 +262,6 @@ def forward(self, inputs):
fpn_outs = paddle.concat(fpn_outs, axis=1)
output = self.fpn_bottleneck(fpn_outs)

if debug:
print('-------head 4----')
print(output.shape, output.numpy().mean())

if self.dropout is not None:
output = self.dropout(output)
output = self.conv_seg(output)
Expand All @@ -292,10 +272,4 @@ def forward(self, inputs):
aux_output = self.aux_conv_seg(aux_output)
logits_list.append(aux_output)

if debug:
print('-------head 5----')
for x in logits_list:
print(x.shape, x.numpy().mean())
exit()

return logits_list