diff --git a/csrc/gpu/aten/operators/BatchNorm.cpp b/csrc/gpu/aten/operators/BatchNorm.cpp
index 110365edc..cf596e03b 100644
--- a/csrc/gpu/aten/operators/BatchNorm.cpp
+++ b/csrc/gpu/aten/operators/BatchNorm.cpp
@@ -584,12 +584,14 @@ void batch_norm_elementwise(
               batch_norm_elemt_channels_first_template<
                   scalar_t,
                   accscalar_t,
-                  int32_t>(out, self, *weight, *bias, mean_, invstd_);
+                  int32_t>(
+                  out, self.contiguous(), *weight, *bias, mean_, invstd_);
             } else {
               batch_norm_elemt_channels_first_template<
                   scalar_t,
                   scalar_t,
-                  int32_t>(out, self, *weight, *bias, mean_, invstd_);
+                  int32_t>(
+                  out, self.contiguous(), *weight, *bias, mean_, invstd_);
             }
           });
       return;
@@ -607,7 +609,16 @@ void batch_norm_elementwise(
           (!mean_.defined() || mean_.is_contiguous()) &&
           (!invstd_.defined() || invstd_.is_contiguous())) {
         batch_norm_elemt_channels_last_template(
-            out, self, *weight, *bias, mean_, invstd_);
+            out,
+            // It is a WA to fix Mobile-SSD convergence issue.
+            // TODO: Fully support: Check and convert activations with any
+            // shapes to align with kernel required memory layout.
+            self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
+                            : self,
+            *weight,
+            *bias,
+            mean_,
+            invstd_);
         return;
       }
     }
@@ -2858,13 +2869,25 @@ Tensor batch_norm_elementwise_backward_train(
                 scalar_t,
                 accscalar_t,
                 int32_t>(
-                grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+                grad_out.contiguous(),
+                input.contiguous(),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu);
           } else {
             return batch_norm_backward_elemt_channels_first_template<
                 scalar_t,
                 scalar_t,
                 int32_t>(
-                grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+                grad_out.contiguous(),
+                input.contiguous(),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu);
           }
         });
   }
@@ -2872,7 +2895,17 @@ Tensor batch_norm_elementwise_backward_train(
       if ((!weight.defined() || weight.is_contiguous()) && mean.is_contiguous() &&
           invstd.is_contiguous()) {
         return batch_norm_backward_elemt_channels_last_template(
-            grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
+            // It is a WA to fix Mobile-SSD convergence issue.
+            grad_out.dim() == 4
+                ? grad_out.contiguous(at::MemoryFormat::ChannelsLast)
+                : grad_out,
+            input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                             : input,
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu);
       }
     }
     case Impl::General: {
@@ -3091,7 +3124,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_dispatch(
       (!weight.defined() || weight.is_contiguous()) && mean.is_contiguous() &&
       invstd.is_contiguous()) {
     return batch_norm_backward_reduce_channels_last_template(
-        grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
+        // It is a WA to fix Mobile-SSD convergence issue.
+        grad_output.dim() == 4
+            ? grad_output.contiguous(at::MemoryFormat::ChannelsLast)
+            : grad_output,
+        input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                         : input,
+        mean,
+        invstd,
+        weight,
+        input_g,
+        weight_g,
+        bias_g);
   }
   return IPEX_DISPATCH_FLOATING_TYPES_AND2(
       kHalf,
@@ -3282,8 +3326,8 @@ std::tuple<Tensor, Tensor, Tensor> native_batch_norm_backward(
                 scalar_t,
                 accscalar_t,
                 int32_t>(
-                grad_output,
-                input,
+                grad_output.contiguous(),
+                input.contiguous(),
                 *weight,
                 *running_mean,
                 *running_var,
@@ -3297,8 +3341,8 @@ std::tuple<Tensor, Tensor, Tensor> native_batch_norm_backward(
                 scalar_t,
                 scalar_t,
                 int32_t>(
-                grad_output,
-                input,
+                grad_output.contiguous(),
+                input.contiguous(),
                 *weight,
                 *running_mean,
                 *running_var,
@@ -3913,7 +3957,17 @@ Tensor batch_norm_backward_elemt_dispatch(
       batch_norm_use_channels_last_kernels(self) &&
      batch_norm_use_channels_last_kernels(input)) {
    return batch_norm_backward_elemt_channels_last_template(
-        self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+        // It is a WA to fix Mobile-SSD convergence issue.
+        self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
+                        : self,
+        input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
+                         : input,
+        mean,
+        invstd,
+        weight,
+        sum_dy,
+        sum_dy_xmu,
+        count);
   }
 
   return IPEX_DISPATCH_FLOATING_TYPES_AND2(
@@ -3938,13 +3992,27 @@ Tensor batch_norm_backward_elemt_dispatch(
                 scalar_t,
                 accscalar_t,
                 int32_t>(
-                self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+                self.contiguous(),
+                input.contiguous(),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count);
           } else {
             return batch_norm_backward_elemt_channels_first_template<
                 scalar_t,
                 scalar_t,
                 int32_t>(
-                self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+                self.contiguous(),
+                input.contiguous(),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count);
           }
         } else {
           if (is_half_float || is_bfloat16_float) {
@@ -3952,13 +4020,27 @@ Tensor batch_norm_backward_elemt_dispatch(
                 scalar_t,
                 accscalar_t,
                 int64_t>(
-                self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+                self.contiguous(),
+                input.contiguous(),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count);
           } else {
             return batch_norm_backward_elemt_channels_first_template<
                 scalar_t,
                 scalar_t,
                 int64_t>(
-                self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+                self.contiguous(),
+                input.contiguous(),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count);
           }
         }
       });
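
For reference, a minimal standalone sketch (not part of the patch) of the layout-normalization pattern the hunks above apply inline; the helper name to_kernel_layout is hypothetical:

#include <ATen/ATen.h>

// Illustrative only: 4-D activations are made channels-last contiguous up
// front so the channels-last kernels see the NHWC strides they expect (the
// Mobile-SSD convergence WA above); other ranks pass through unchanged.
static at::Tensor to_kernel_layout(const at::Tensor& t) {
  return t.dim() == 4 ? t.contiguous(at::MemoryFormat::ChannelsLast) : t;
}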