Skip to content

Commit

Permalink
Fix sfnet export issue
Browse files Browse the repository at this point in the history
  • Loading branch information
nepeplwu committed Jun 23, 2021
1 parent 34c1bbf commit 30860e6
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions paddleseg/models/sfnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, x):
logit_list = [
F.interpolate(
logit,
x.shape[2:],
paddle.shape(x)[2:],
mode='bilinear',
align_corners=self.align_corners) for logit in logit_list
]
Expand Down Expand Up @@ -165,7 +165,7 @@ def forward(self, conv_out):
out.append(self.dsn[i](f))

fpn_feature_list.reverse()
output_size = fpn_feature_list[0].shape[2:]
output_size = paddle.shape(fpn_feature_list[0])[2:]
fusion_list = [fpn_feature_list[0]]

for i in range(1, len(fpn_feature_list)):
Expand Down Expand Up @@ -205,24 +205,25 @@ def __init__(self, inplane, outplane, kernel_size=3):
padding=1,
bias_attr=False)

def flow_warp(self, inputs, flow, size):
out_h, out_w = size
n, c, h, w = inputs.shape
norm = paddle.to_tensor([[[[out_w, out_h]]]]).astype('float32')
h = paddle.linspace(-1.0, 1.0, out_h).reshape([-1, 1]).tile([1, out_w])
w = paddle.linspace(-1.0, 1.0, out_w).tile([out_h, 1])
grid = paddle.concat([paddle.unsqueeze(w, 2),
paddle.unsqueeze(h, 2)], 2)
grid = grid.tile([n, 1, 1, 1]).astype('float32')
grid = grid + flow.transpose([0, 2, 3, 1]) / norm
output = F.grid_sample(inputs, grid)
def flow_warp(self, input, flow, size):
input_shape = paddle.shape(input)
norm = size[::-1].reshape([1, 1, 1, -1])
norm.stop_gradient = True
h_grid = paddle.linspace(-1.0, 1.0, size[0]).reshape([-1, 1])
h_grid = h_grid.tile([size[1]])
w_grid = paddle.linspace(-1.0, 1.0, size[1]).reshape([-1, 1])
w_grid = w_grid.tile([size[0]]).transpose([1, 0])
grid = paddle.concat([w_grid.unsqueeze(2), h_grid.unsqueeze(2)], axis=2)
grid.unsqueeze(0).tile([input_shape[0], 1, 1, 1])
grid = grid + paddle.transpose(flow, (0, 2, 3, 1)) / norm

output = F.grid_sample(input, grid)
return output

def forward(self, x):
low_feature, h_feature = x
h_feature_orign = h_feature
h, w = low_feature.shape[2:]
size = (h, w)
size = paddle.shape(low_feature)[2:]
low_feature = self.down_l(low_feature)
h_feature = self.down_h(h_feature)
h_feature = F.interpolate(
Expand Down

0 comments on commit 30860e6

Please sign in to comment.