@@ -755,7 +755,7 @@ def max_unpool1d(
755755 x (Tensor): The input tensor of unpooling operator which is a 3-D tensor with
756756 shape [N, C, L]. The format of input tensor is `"NCL"`,
757757 where `N` is batch size, `C` is the number of channels, `L` is
758- the length of the feature. The data type is float32 or float64 .
758+ the length of the feature. The data type is float32, float64 or int64 .
759759 indices (Tensor): The indices given out by maxpooling1d which is a 3-D tensor with
760760 shape [N, C, L]. The format of input tensor is `"NCL"` ,
761761 where `N` is batch size, `C` is the number of channels, `L` is
@@ -813,6 +813,8 @@ def max_unpool1d(
813813 # use 2d to implenment 1d should expand padding in advance.
814814 padding = _expand_low_nd_padding (padding )
815815
816+ if output_size is not None :
817+ output_size = output_size [:2 ] + [1 ] + output_size [2 :]
816818 output_size = _unpool_output_size (
817819 x , kernel_size , stride , padding , output_size
818820 )
@@ -863,12 +865,12 @@ def max_unpool2d(
863865 shape [N, C, H, W]. The format of input tensor is `"NCHW"`,
864866 where `N` is batch size, `C` is the number of channels,
865867 `H` is the height of the feature, and `W` is the width of the
866- feature. The data type if float32 or float64 .
868+ feature. The data type is float32, float64 or int64 .
867869 indices (Tensor): The indices given out by maxpooling2d which is a 4-D tensor with
868870 shape [N, C, H, W]. The format of input tensor is `"NCHW"` ,
869871 where `N` is batch size, `C` is the number of channels,
870872 `H` is the height of the feature, and `W` is the width of the
871- feature. The data type if float32 or float64.
873+ feature. The data type is float32 or float64.
872874 kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
873875 it must contain an integer.
874876 stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
@@ -1011,7 +1013,7 @@ def max_unpool3d(
10111013 shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"`,
10121014 where `N` is batch size, `C` is the number of channels, `D` is
10131015 the depth of the feature, `H` is the height of the feature,
1014- and `W` is the width of the feature. The data type is float32 or float64 .
1016+ and `W` is the width of the feature. The data type is float32, float64 or int64 .
10151017 indices (Tensor): The indices given out by maxpooling3d which is a 5-D tensor with
10161018 shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"` ,
10171019 where `N` is batch size, `C` is the number of channels, `D` is
0 commit comments