Skip to content

Commit

Permalink
Add support for non-affine batch norm with float stats and half inputs (
Browse files Browse the repository at this point in the history
pytorch#22750)

Summary:
This PR creates support for non-affine batch norm with float running estimates and half inputs.
Changes were made similar to those in pytorch#16735.

I couldn't find a specific test for `SyncBatchNorm`, so I used [this code](https://gist.github.com/ptrblck/ab45bfcde6df55ac28a7be18531f4718) to test it.

cc ngimel
Pull Request resolved: pytorch#22750

Differential Revision: D17119965

Pulled By: ezyang

fbshipit-source-id: 2e8c5d63fc3c636b8a1338c43c9c101a0f5e9b22
  • Loading branch information
ptrblck authored and facebook-github-bot committed Aug 29, 2019
1 parent fe922a2 commit 8640aef
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 153 deletions.
88 changes: 74 additions & 14 deletions aten/src/ATen/native/cuda/Normalization.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,45 @@ namespace at { namespace native {
// Forward batch norm on CUDA.
// Dispatches on (a) the input's floating dtype, (b) whether the running
// stats are kept in float while the input is half (mixed-precision,
// non-affine batch norm), and (c) whether 32-bit index math suffices.
// Returns (output, save_mean, save_invstd).
std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias,
                                                   const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) {
  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_cuda", [&] {
    auto mean_st = running_mean.dtype();
    auto var_st = running_var.dtype();
    // Mixed dtypes between the two running estimates are not supported.
    TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
    // Half inputs may carry float running estimates; the kernels then use
    // float as the stats scalar type instead of the input scalar type.
    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (is_half_float) {
        return batch_norm_cuda_template<at::Half, float, int32_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);
      } else {
        return batch_norm_cuda_template<scalar_t, scalar_t, int32_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);
      }
    } else {
      if (is_half_float) {
        return batch_norm_cuda_template<at::Half, float, int64_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);
      } else {
        return batch_norm_cuda_template<scalar_t, scalar_t, int64_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);
      }
    }
  });
}

// Backward batch norm on CUDA.
// Same dispatch scheme as the forward: input dtype, half-input/float-stats
// mixed precision, and 32- vs 64-bit index math.
// grad_input_mask selects which of (grad_input, grad_weight, grad_bias)
// to compute.
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var,
                                                            const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_cuda", [&] {
    auto mean_st = running_mean.dtype();
    auto var_st = running_var.dtype();
    TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
    // Half inputs with float running estimates use float as the stats type.
    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (is_half_float) {
        return batch_norm_backward_cuda_template<at::Half, float, int32_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
      } else {
        return batch_norm_backward_cuda_template<scalar_t, scalar_t, int32_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
      }
    } else {
      if (is_half_float) {
        return batch_norm_backward_cuda_template<at::Half, float, int64_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
      } else {
        return batch_norm_backward_cuda_template<scalar_t, scalar_t, int64_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
      }
    }
  });
}
Expand All @@ -37,10 +61,22 @@ std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsi
// Elementwise normalization step of (sync) batch norm on CUDA:
// applies the given per-channel mean/invstd (and optional weight/bias)
// to the input. Dispatches like the forward: input dtype,
// half-input/float-stats mixed precision, and index width.
Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias,
                             const Tensor& mean, const Tensor& invstd, double epsilon) {
  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_elemt", [&] {
    auto mean_st = mean.dtype();
    auto invstd_st = invstd.dtype();
    // The two statistics tensors must share a dtype.
    TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
    // Half inputs may pair with float statistics (mixed precision).
    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (is_half_float) {
        return batch_norm_elemt_cuda_template<at::Half, float, int32_t>(self, weight, bias, mean, invstd, epsilon);
      } else {
        return batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(self, weight, bias, mean, invstd, epsilon);
      }
    } else {
      if (is_half_float) {
        return batch_norm_elemt_cuda_template<at::Half, float, int64_t>(self, weight, bias, mean, invstd, epsilon);
      } else {
        return batch_norm_elemt_cuda_template<scalar_t, scalar_t, int64_t>(self, weight, bias, mean, invstd, epsilon);
      }
    }
  });
}
Expand All @@ -56,8 +92,8 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda(const Tensor& self, cons
std::tuple<Tensor, Tensor> batch_norm_gather_stats_with_counts_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean,
const Tensor& running_var, double momentum, double epsilon, IntArrayRef counts) {
Tensor counts_ = at::from_blob((void*)counts.data(), {(int64_t)counts.size()}, self.options().dtype(at::kLong).device(at::kCPU));
counts_ = counts_.to(self.device()).to(self.dtype());
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_update_stats_cuda", [&] {
counts_ = counts_.to(self.device()).to(running_mean.dtype());
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(running_mean.scalar_type(), "batch_norm_update_stats_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
if (cuda::detail::canUse32BitIndexMath(self)) {
return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int32_t>(mean, invstd, running_mean, running_var, momentum, epsilon, counts_);
Expand All @@ -67,24 +103,48 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_with_counts_cuda(const Tensor
});
}

std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean,
const Tensor& invstd, bool input_g, bool weight_g, bool bias_g) {
std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd,
const Tensor& weight, bool input_g, bool weight_g, bool bias_g) {
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_reduce", [&] {
auto mean_st = mean.dtype();
auto invstd_st = invstd.dtype();
TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
if (cuda::detail::canUse32BitIndexMath(self)) {
return batch_norm_backward_reduce_cuda_template<scalar_t, int32_t>(self, input, mean, invstd, input_g, weight_g, bias_g);
if (is_half_float) {
return batch_norm_backward_reduce_cuda_template<at::Half, float, int32_t>(self, input, mean, invstd, weight, input_g, weight_g, bias_g);
} else {
return batch_norm_backward_reduce_cuda_template<scalar_t, scalar_t, int32_t>(self, input, mean, invstd, weight, input_g, weight_g, bias_g);
}
} else {
return batch_norm_backward_reduce_cuda_template<scalar_t, int64_t>(self, input, mean, invstd, input_g, weight_g, bias_g);
if (is_half_float) {
return batch_norm_backward_reduce_cuda_template<at::Half, float, int64_t>(self, input, mean, invstd, weight, input_g, weight_g, bias_g);
} else {
return batch_norm_backward_reduce_cuda_template<scalar_t, scalar_t, int64_t>(self, input, mean, invstd, weight, input_g, weight_g, bias_g);
}
}
});
}

// Elementwise step of the sync batch norm backward pass on CUDA:
// combines the reduced quantities (mean_dy, mean_dy_xmu) with the saved
// statistics to produce grad_input. Dispatch mirrors the other entry
// points: input dtype, half-input/float-stats mixed precision, and
// 32- vs 64-bit index math.
Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd,
                                      const Tensor& weight, const Tensor& mean_dy, const Tensor& mean_dy_xmu) {
  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_elemt", [&] {
    auto mean_st = mean.dtype();
    auto invstd_st = invstd.dtype();
    TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
    // Half inputs may pair with float statistics (mixed precision).
    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
    if (cuda::detail::canUse32BitIndexMath(self)) {
      if (is_half_float) {
        return batch_norm_backward_elemt_cuda_template<at::Half, float, int32_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int32_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
      }
    } else {
      if (is_half_float) {
        return batch_norm_backward_elemt_cuda_template<at::Half, float, int64_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
      } else {
        return batch_norm_backward_elemt_cuda_template<scalar_t, scalar_t, int64_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
      }
    }
  });
}
Expand Down
Loading

0 comments on commit 8640aef

Please sign in to comment.