Skip to content

Commit

Permalink
[Cherry-pick] #1511 #1449 #1518 (#1572)
Browse files (browse the repository at this point in the history)
  • Loading branch information
shiyutang authored Nov 26, 2021
1 parent ff7a304 commit 5e6cab5
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 98 deletions.
34 changes: 17 additions & 17 deletions paddleseg/models/ginet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
class GINet(nn.Layer):
"""
The GINet implementation based on PaddlePaddle.
The original article refers to
Wu, Tianyi, Yu Lu, Yu Zhu, Chuang Zhang, Ming Wu, Zhanyu Ma, and Guodong Guo. "GINet: Graph interaction network for scene parsing." In European Conference on Computer Vision, pp. 34-51. Springer, Cham, 2020.
(https://arxiv.org/pdf/2009.06160).
Args:
num_classes (int): The unique number of target classes.
backbone (Paddle.nn.Layer): Backbone network.
Expand Down Expand Up @@ -71,6 +69,7 @@ def __init__(self,

def base_forward(self, x):
feat_list = self.backbone(x)

c1, c2, c3, c4 = [feat_list[i] for i in self.backbone_indices]

if self.jpu:
Expand All @@ -79,7 +78,7 @@ def base_forward(self, x):
return c1, c2, c3, c4

def forward(self, x):
_, _, h, w = x.shape
_, _, h, w = paddle.shape(x)
_, _, c3, c4 = self.base_forward(x)

logit_list = []
Expand Down Expand Up @@ -115,6 +114,7 @@ def __init__(self, in_channels, nclass):
shape=self.inp.shape,
dtype=str(self.inp.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(self.inp))
self.inp.stop_gradient = True

self.fc1 = nn.Sequential(
nn.Linear(300, 128), nn.BatchNorm1D(128), nn.ReLU())
Expand All @@ -137,8 +137,9 @@ def __init__(self, in_channels, nclass):
nn.Dropout(0.1), nn.Conv2D(inter_channels, nclass, 1))

def forward(self, x):
B, C, H, W = x.shape
inp = self.inp.detach()

B, C, H, W = paddle.shape(x)
inp = self.inp

inp = self.fc1(inp)
inp = self.fc2(inp).unsqueeze(axis=0).transpose((0, 2, 1))\
Expand Down Expand Up @@ -172,20 +173,19 @@ def __init__(self, in_channels, num_state=256, num_node=84, nclass=59):

def forward(self, x, inp):
B = self.conv_theta(x)
sizeB = B.shape
B = B.reshape((sizeB[0], sizeB[1], -1))
sizeB = paddle.shape(B)
B = paddle.flatten(B, 2, 3)

sizex = x.shape
sizex = paddle.shape(x)
x_reduce = self.conv_phi(x)
x_reduce = x_reduce.reshape((sizex[0], -1, sizex[2] * sizex[3]))\
.transpose((0, 2, 1))

x_reduce = paddle.flatten(x_reduce, 2, 3).transpose((0, 2, 1))

V = paddle.bmm(B, x_reduce).transpose((0, 2, 1))
V = paddle.divide(
V, paddle.to_tensor([sizex[2] * sizex[3]], dtype='float32'))
V = paddle.divide(V, (sizex[2] * sizex[3]).astype('float32'))

class_node, new_V = self.graph(inp, V)
D = B.reshape((sizeB[0], -1, sizeB[2] * sizeB[3])).transpose((0, 2, 1))
D = B.transpose((0, 2, 1))
Y = paddle.bmm(D, new_V.transpose((0, 2, 1)))
Y = Y.transpose((0, 2, 1)).reshape((sizex[0], self.num_state, \
sizex[2], -1))
Expand All @@ -205,11 +205,11 @@ def __init__(self, num_state, num_node, num_class):
self.gamma_vis = paddle.zeros([num_node])
self.gamma_word = paddle.zeros([num_class])
self.gamma_vis = paddle.create_parameter(
shape=self.gamma_vis.shape,
shape=paddle.shape(self.gamma_vis),
dtype=str(self.gamma_vis.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(self.gamma_vis))
self.gamma_word = paddle.create_parameter(
shape=self.gamma_word.shape,
shape=paddle.shape(self.gamma_word),
dtype=str(self.gamma_word.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(self.gamma_word))

Expand Down Expand Up @@ -270,8 +270,8 @@ def __init__(self, in_dim):
self.softmax_word = nn.Softmax(axis=-2)

def forward(self, word, vis_node):
m_batchsize, C, Nc = word.shape
m_batchsize, C, Nn = vis_node.shape
m_batchsize, C, Nc = paddle.shape(word)
m_batchsize, C, Nn = paddle.shape(vis_node)

proj_query = self.query_conv(word).reshape((m_batchsize, -1, Nc))\
.transpose((0, 2, 1))
Expand Down
2 changes: 1 addition & 1 deletion paddleseg/models/layers/layer_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def forward(self, *inputs):
self.conv4(inputs[-2]),
self.conv3(inputs[-3])
]
size = feats[-1].shape[2:]
size = paddle.shape(feats[-1])[2:]
feats[-2] = F.interpolate(
feats[-2], size, mode='bilinear', align_corners=True)
feats[-3] = F.interpolate(
Expand Down
77 changes: 40 additions & 37 deletions paddleseg/models/pointrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __init__(
self.importance_sample_ratio = importance_sample_ratio
self.scale_factor = scale_factor
self.subdivision_steps = subdivision_steps
self.subdivision_num_points = subdivision_num_points
self.subdivision_num_points = paddle.to_tensor(subdivision_num_points, dtype="int32")
self.dropout_ratio = dropout_ratio
self.coarse_pred_each_layer = coarse_pred_each_layer
self.align_corners = align_corners
Expand Down Expand Up @@ -332,7 +332,7 @@ def _transform_inputs(self, inputs):
upsampled_inputs = [
F.interpolate(
x,
size=inputs[0].shape[2:],
size=paddle.shape(inputs[0])[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
Expand Down Expand Up @@ -367,7 +367,7 @@ def get_points_train(self, seg_logits, uncertainty_func): # finish
importance_sample_ratio = self.importance_sample_ratio
assert oversample_ratio >= 1
assert 0 <= importance_sample_ratio <= 1
batch_size = seg_logits.shape[0]
batch_size = paddle.shape(seg_logits)[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = paddle.rand([batch_size, num_sampled, 2])
point_logits = point_sample(seg_logits, point_coords)
Expand Down Expand Up @@ -419,12 +419,14 @@ def get_points_test(self, seg_logits, uncertainty_func): # finish

num_points = self.subdivision_num_points
uncertainty_map = uncertainty_func(seg_logits)
batch_size, _, height, width = uncertainty_map.shape
batch_size = paddle.shape(uncertainty_map)[0]
height = paddle.shape(uncertainty_map)[2]
width = paddle.shape(uncertainty_map)[3]
h_step = 1.0 / height
w_step = 1.0 / width

uncertainty_map = uncertainty_map.reshape([batch_size, height * width])
num_points = min(height * width, num_points)
num_points = paddle.min(paddle.concat([height * width, num_points]))
point_indices = paddle.topk(uncertainty_map, num_points, axis=1)[1]
point_coords = paddle.zeros([batch_size, num_points, 2],
dtype='float32')
Expand All @@ -446,10 +448,10 @@ def scatter_paddle(self, refined_seg_logits, point_indices, point_logits):
scattered refined_seg_logits(Tensor).
"""

original_shape = refined_seg_logits.shape # [batch_size, channels, height * width]
original_shape = paddle.shape(refined_seg_logits) # [batch_size, channels, height * width]
new_refined_seg_logits = refined_seg_logits.flatten(0, 1) # [N*C,H*W]
offsets = (paddle.arange(new_refined_seg_logits.shape[0]) *
new_refined_seg_logits.shape[1]).unsqueeze(-1) # [N*C,1]
offsets = (paddle.arange(paddle.shape(new_refined_seg_logits)[0]) *
paddle.shape(new_refined_seg_logits)[1]).unsqueeze(-1) # [N*C,1]
point_indices = point_indices.flatten(0, 1) # [N*C,H*W]
new_point_indices = (point_indices + offsets).flatten()
point_logits = point_logits.flatten() # [N*C*H*W]
Expand All @@ -460,6 +462,25 @@ def scatter_paddle(self, refined_seg_logits, point_indices, point_logits):
overwrite=True)
return refined_seg_logits.reshape(shape=original_shape)

def forward_train(self, x, prev_output):
    """
    Training-time forward pass of the point head.

    Samples points from ``prev_output`` with gradients disabled, gathers the
    fine-grained features of ``x`` and the coarse features of ``prev_output``
    at those points, passes the fused features through every layer in
    ``self.fcs`` and classifies them with ``self.cls_seg``.

    Args:
        x: Features returned by ``self._transform_inputs`` in ``forward``.
        prev_output: Output of the previous head; it is handed to
            ``get_points_train`` as its ``seg_logits`` argument.

    Returns:
        list: ``[point_logits, points]``; the points are returned alongside
            the logits for the points loss.
    """
    # NOTE(review): the nesting below was reconstructed from a flattened diff
    # view; the per-layer concat is assumed to sit inside the loop -- confirm
    # against the repository file.

    # Choosing where to sample is kept out of the autograd graph.
    with paddle.no_grad():
        sampled_points = self.get_points_train(prev_output,
                                               calculate_uncertainty)

    # Example shapes noted by the original author: [2, 256, 2048] (fine)
    # and [2, 19, 2048] (coarse).
    fine_feats = self._get_fine_grained_point_feats(x, sampled_points)
    coarse_feats = self._get_coarse_point_feats(prev_output, sampled_points)

    fused = paddle.concat([fine_feats, coarse_feats], axis=1)
    for layer in self.fcs:
        fused = layer(fused)
        if self.coarse_pred_each_layer:
            # Re-attach the coarse prediction after each layer.
            fused = paddle.concat((fused, coarse_feats), axis=1)

    return [self.cls_seg(fused), sampled_points]

def forward(self, inputs, prev_output):
"""
Forward function.
Expand All @@ -475,24 +496,7 @@ def forward(self, inputs, prev_output):
prev_output = prev_output[0]
x = self._transform_inputs(inputs)
if self.training:
with paddle.no_grad():
points = self.get_points_train(prev_output,
calculate_uncertainty)

fine_grained_point_feats = self._get_fine_grained_point_feats(
x, points) # [2, 256, 2048]
coarse_point_feats = self._get_coarse_point_feats(
prev_output, points) # [2, 19, 2048]
# forward for train
fusion_point_feats = paddle.concat(
[fine_grained_point_feats, coarse_point_feats], axis=1)
for fc in self.fcs:
fusion_point_feats = fc(fusion_point_feats)
if self.coarse_pred_each_layer:
fusion_point_feats = paddle.concat(
(fusion_point_feats, coarse_point_feats), axis=1)
point_logits = self.cls_seg(fusion_point_feats)
return [point_logits, points] # for points loss
return self.forward_train(x, prev_output)
else:
refined_seg_logits = prev_output.clone()
for _ in range(self.subdivision_steps):
Expand All @@ -501,7 +505,8 @@ def forward(self, inputs, prev_output):
scale_factor=self.scale_factor,
mode='bilinear',
align_corners=self.align_corners)
batch_size, channels, height, width = refined_seg_logits.shape

save_shape = paddle.shape(refined_seg_logits)
point_indices, points = self.get_points_test(
refined_seg_logits, calculate_uncertainty)
fine_grained_point_feats = self._get_fine_grained_point_feats(
Expand All @@ -518,14 +523,13 @@ def forward(self, inputs, prev_output):
(fusion_point_feats, coarse_point_feats), axis=1)
point_logits = self.cls_seg(fusion_point_feats)
point_indices = paddle.unsqueeze(point_indices, axis=1)
point_indices = paddle.expand(point_indices, [-1, channels, -1])
refined_seg_logits = refined_seg_logits.reshape(
[batch_size, channels, height * width])
point_indices = paddle.expand(point_indices, [-1, save_shape[1], -1])

refined_seg_logits = paddle.flatten(refined_seg_logits, 2)
refined_seg_logits = self.scatter_paddle(
refined_seg_logits, point_indices,
point_logits) # 2->height * width dim
refined_seg_logits = refined_seg_logits.reshape(
[batch_size, channels, height, width])
refined_seg_logits = refined_seg_logits.reshape(save_shape)
return [refined_seg_logits]


Expand Down Expand Up @@ -626,7 +630,7 @@ def _transform_inputs(self, inputs):
upsampled_inputs = [
F.interpolate(
x,
size=inputs[0].shape[2:],
size=paddle.shape(inputs[0])[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
Expand All @@ -644,7 +648,7 @@ def forward(self, inputs):
for i in range(1, len(self.feature_strides)):
output = output + F.interpolate(
self.scale_heads[i](x[i]),
size=output.shape[2:],
size=paddle.shape(output)[2:],
mode='bilinear',
align_corners=self.align_corners)
output = self.cls_seg(output)
Expand Down Expand Up @@ -767,10 +771,9 @@ def __init__(self,

def forward(self, x):
if not self.size:
size = [int(t * self.scale_factor) for t in x.shape[-2:]]
return F.interpolate(x, None, self.scale_factor, self.mode, self.align_corners)
else:
size = self.size
return F.interpolate(x, size, None, self.mode, self.align_corners)
return F.interpolate(x, self.size, None, self.mode, self.align_corners)


def point_sample(input, points, align_corners=False, **kwargs):
Expand Down
71 changes: 28 additions & 43 deletions paddleseg/models/stdcseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,57 +57,42 @@ def __init__(self,
self.use_boundary_16 = use_boundary_16
self.cp = ContextPath(backbone, use_conv_last=use_conv_last)
self.ffm = FeatureFusionModule(384, 256)
self.conv_out = BiSeNetOutput(256, 256, num_classes)
self.conv_out16 = BiSeNetOutput(128, 64, num_classes)
self.conv_out32 = BiSeNetOutput(128, 64, num_classes)
self.conv_out_sp16 = BiSeNetOutput(512, 64, 1)
self.conv_out_sp8 = BiSeNetOutput(256, 64, 1)
self.conv_out_sp4 = BiSeNetOutput(64, 64, 1)
self.conv_out_sp2 = BiSeNetOutput(32, 64, 1)
self.conv_out = SegHead(256, 256, num_classes)
self.conv_out8 = SegHead(128, 64, num_classes)
self.conv_out16 = SegHead(128, 64, num_classes)
self.conv_out_sp16 = SegHead(512, 64, 1)
self.conv_out_sp8 = SegHead(256, 64, 1)
self.conv_out_sp4 = SegHead(64, 64, 1)
self.conv_out_sp2 = SegHead(32, 64, 1)
self.pretrained = pretrained
self.init_weight()

def forward(self, x):
x_hw = paddle.shape(x)[2:]
feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(
x)
feat_res2, feat_res4, feat_res8, _, feat_cp8, feat_cp16 = self.cp(x)

logit_list = []
if self.training:
feat_out_sp2 = self.conv_out_sp2(feat_res2)
feat_out_sp4 = self.conv_out_sp4(feat_res4)
feat_out_sp8 = self.conv_out_sp8(feat_res8)
feat_out_sp16 = self.conv_out_sp16(feat_res16)
feat_fuse = self.ffm(feat_res8, feat_cp8)
feat_out = self.conv_out(feat_fuse)
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)
feat_out = F.interpolate(
feat_out, x_hw, mode='bilinear', align_corners=True)
feat_out16 = F.interpolate(
feat_out16, x_hw, mode='bilinear', align_corners=True)
feat_out32 = F.interpolate(
feat_out32, x_hw, mode='bilinear', align_corners=True)

if self.use_boundary_2 and self.use_boundary_4 and self.use_boundary_8:
logit_list = [
feat_out, feat_out16, feat_out32, feat_out_sp2,
feat_out_sp4, feat_out_sp8
]

if (not self.use_boundary_2
) and self.use_boundary_4 and self.use_boundary_8:
logit_list = [
feat_out, feat_out16, feat_out32, feat_out_sp4, feat_out_sp8
]

if (not self.use_boundary_2) and (
not self.use_boundary_4) and self.use_boundary_8:
logit_list = [feat_out, feat_out16, feat_out32, feat_out_sp8]

if (not self.use_boundary_2) and (not self.use_boundary_4) and (
not self.use_boundary_8):
logit_list = [feat_out, feat_out16, feat_out32]
feat_out8 = self.conv_out8(feat_cp8)
feat_out16 = self.conv_out16(feat_cp16)

logit_list = [feat_out, feat_out8, feat_out16]
logit_list = [
F.interpolate(x, x_hw, mode='bilinear', align_corners=True)
for x in logit_list
]

if self.use_boundary_2:
feat_out_sp2 = self.conv_out_sp2(feat_res2)
logit_list.append(feat_out_sp2)
if self.use_boundary_4:
feat_out_sp4 = self.conv_out_sp4(feat_res4)
logit_list.append(feat_out_sp4)
if self.use_boundary_8:
feat_out_sp8 = self.conv_out_sp8(feat_res8)
logit_list.append(feat_out_sp8)
else:
feat_fuse = self.ffm(feat_res8, feat_cp8)
feat_out = self.conv_out(feat_fuse)
Expand All @@ -122,9 +107,9 @@ def init_weight(self):
utils.load_entire_model(self, self.pretrained)


class BiSeNetOutput(nn.Layer):
class SegHead(nn.Layer):
def __init__(self, in_chan, mid_chan, n_classes):
super(BiSeNetOutput, self).__init__()
super(SegHead, self).__init__()
self.conv = layers.ConvBNReLU(
in_chan, mid_chan, kernel_size=3, stride=1, padding=1)
self.conv_out = nn.Conv2D(
Expand Down

0 comments on commit 5e6cab5

Please sign in to comment.