@@ -33,8 +33,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
33
33
self .ref_thickness = ref_thickness
34
34
self .shift_padding_zero = shift_padding_zero
35
35
36
- def alignshift (self , input , fold , thickness , padding_zero = True ):
37
- alph = 1 - (self . ref_thickness / thickness ).view (- 1 , 1 ,1 , 1 , 1 ).clamp_ (0. , 1. )
36
+ def alignshift (self , input , fold , ref_thickness , thickness , padding_zero = True ):
37
+ alph = 1 - (ref_thickness / thickness ).view (- 1 , 1 ,1 , 1 , 1 ).clamp_ (0. , 1. )
38
38
out = torch .zeros_like (input )
39
39
out [:, :fold , :- 1 ] = input [:, :fold , :- 1 ] * alph + input [:, :fold , 1 :] * (1 - alph )
40
40
out [:, fold : 2 * fold , 1 :] = input [:, fold : 2 * fold , 1 :] * alph + \
@@ -49,7 +49,7 @@ def align_shift(self, x, fold, ref_thickness, thickness, padding_zero, inplace):
49
49
if inplace :
50
50
x = inplace_alignshift (x , fold , ref_thickness , thickness , padding_zero )
51
51
else :
52
- x = self .alignshift (x , fold , thickness , padding_zero )
52
+ x = self .alignshift (x , fold , ref_thickness , thickness , padding_zero )
53
53
return x
54
54
55
55
def forward (self , input , thickness = None ):
@@ -67,7 +67,7 @@ def extra_repr(self):
67
67
68
68
class InplaceAlignShift (torch .autograd .Function ):
69
69
@staticmethod
70
- def forward (ctx , input , fold , align_spacing , thickness , padding_zero = True ):
70
+ def forward (ctx , input , fold , ref_thickness , thickness , padding_zero = True ):
71
71
'''
72
72
@params:
73
73
input: BxCxDxHxW
@@ -79,7 +79,7 @@ def forward(ctx, input, fold, align_spacing, thickness, padding_zero=True):
79
79
n , c , t , h , w = input .size ()
80
80
ctx .fold_ = fold
81
81
ctx .padding_zero = padding_zero
82
- alph = 1 - (align_spacing / thickness ).view (- 1 , 1 ,1 , 1 , 1 ).clamp_ (0. , 1. ) ##把小于align_spacing的当作align_spacing ,不做插值处理
82
+ alph = 1 - (ref_thickness / thickness ).view (- 1 , 1 ,1 , 1 , 1 ).clamp_ (0. , 1. ) ##把小于ref_thickness的当作ref_thickness ,不做插值处理
83
83
ctx .alph_ = alph
84
84
input .data [:, :fold , :- 1 ] = input .data [:, :fold , :- 1 ] * alph + input .data [:, :fold , 1 :] * (1 - alph )
85
85
input .data [:, fold :2 * fold , 1 :] = input .data [:, fold :2 * fold , 1 :] * alph + \
0 commit comments