
Commit 0d7e780

albanD authored and facebook-github-bot committed
Fix broadcasting of cdist backward (pytorch#56605)
Summary:
Pull Request resolved: pytorch#56605

Fix pytorch#55370

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D27939202

Pulled By: albanD

fbshipit-source-id: a4ac50a7b504c24f47f5343414fb57523546a0c7
1 parent 3ddcc8d commit 0d7e780
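
For context, a minimal Python sketch of the call pattern this commit fixes (shapes taken from the new small_S test cases below; the printed gradient shapes follow from autograd's usual broadcast reduction and are not output copied from the PR):

import torch

# Batch dims (2, 1) and (1,) broadcast to (2, 1); before this fix, the backward
# of a broadcasted cdist mishandled the batch dimensions (pytorch#55370).
x1 = torch.randn(2, 1, 1, 2, requires_grad=True)   # (batch..., r1=1, m=2)
x2 = torch.randn(1, 2, 2, requires_grad=True)      # (batch..., r2=2, m=2)

d = torch.cdist(x1, x2, p=2)   # shape (2, 1, 1, 2)
d.sum().backward()             # exercises _cdist_backward with broadcasting

print(x1.grad.shape)  # torch.Size([2, 1, 1, 2])
print(x2.grad.shape)  # torch.Size([1, 2, 2])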

3 files changed: +43 -15 lines changed


aten/src/ATen/native/Distance.cpp

Lines changed: 34 additions & 6 deletions
@@ -145,7 +145,31 @@ Tensor _cdist_forward(const Tensor& x1, const Tensor& x2, const double p, c10::o
   return result;
 }
 
-Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& cdist) {
+Tensor _cdist_backward(const Tensor& grad, const Tensor& _x1, const Tensor& _x2, const double p, const Tensor& cdist) {
+  // Broadcasting might generate non-contiguous Tensors, so handle it before doing checks
+  int64_t c1 = _x1.size(-1);
+  int64_t c2 = _x2.size(-1);
+  int64_t r1 = _x1.size(-2);
+  int64_t r2 = _x2.size(-2);
+  auto dim1 = _x1.dim();
+  auto dim2 = _x2.dim();
+  IntArrayRef batch_tensor1(_x1.sizes().data(), dim1 - 2);
+  IntArrayRef batch_tensor2(_x2.sizes().data(), dim2 - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+  std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+  tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
+  std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+  tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
+
+  Tensor x1 = _x1;
+  if (tensor1_expand_size != x1.sizes()) {
+    x1 = x1.expand(tensor1_expand_size).contiguous();
+  }
+  Tensor x2 = _x2;
+  if (tensor2_expand_size != x2.sizes()) {
+    x2 = x2.expand(tensor2_expand_size).contiguous();
+  }
+
   TORCH_CHECK(x1.is_contiguous(), "_cdist_backward requires X1 to be contiguous");
   TORCH_CHECK(x2.is_contiguous(), "_cdist_backward requires X2 to be contiguous");
   TORCH_CHECK(cdist.is_contiguous(), "_cdist_backward requires dist to be contiguous");
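
For reference, a Python sketch of the batch-broadcast-and-expand step above, using torch.broadcast_shapes in place of ATen's infer_size (example shapes are made up; the C++ code additionally skips the expand when the size already matches):

import torch

_x1 = torch.randn(2, 1, 4, 3)   # batch (2, 1), r1=4, c1=3
_x2 = torch.randn(5, 6, 3)      # batch (5,),   r2=6, c2=3

batch1, batch2 = _x1.shape[:-2], _x2.shape[:-2]
expand_batch = torch.broadcast_shapes(batch1, batch2)          # (2, 5)

x1 = _x1.expand(*expand_batch, *_x1.shape[-2:]).contiguous()   # (2, 5, 4, 3)
x2 = _x2.expand(*expand_batch, *_x2.shape[-2:]).contiguous()   # (2, 5, 6, 3)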
@@ -156,13 +180,17 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
   TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X1 got: ", device1);
   auto device2 = x2.device().type();
   TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);
-  IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(x1.dim() - 2, 0));
-  const int64_t batch_product = c10::multiply_integers(batch_tensor1);
+
+  // Compute the linearized batch size
+  const int64_t batch_product = c10::multiply_integers(expand_batch_portion);
+
   Tensor grad_x1 =
-      at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT)
-      .view({batch_product, n, m});
+      at::empty({batch_product, n, m}, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist);
-  return grad_x1;
+
+  // Use x1.size() here and not the original size of _x1.size() as this gradient is not taking broadcasting into account
+  // Broadcasting will be handled automatically by the autograd engine
+  return grad_x1.view(x1.sizes());
 }
 
 Tensor _pdist_forward(const Tensor& self, const double p) {
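
The returned gradient has x1's expanded shape rather than _x1's original shape; a small Python sketch of why that is enough (the autograd engine sum-reduces broadcast dimensions back to the input's shape):

import torch

_x1 = torch.randn(1, 4, 3, requires_grad=True)
x1 = _x1.expand(2, 4, 3)            # broadcasted view, as inside the op
x1.backward(torch.ones(2, 4, 3))    # gradient given at the expanded shape
print(_x1.grad.shape)               # torch.Size([1, 4, 3]) -- reduced over dim 0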

aten/src/ATen/native/cuda/DistanceKernel.cu

Lines changed: 4 additions & 7 deletions
@@ -331,7 +331,8 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
   const int64_t r1 = x1.size(-2);
   const int64_t r2 = x2.size(-2);
   const int64_t m = x1.size(-1);
-  int64_t batch = x1.dim() > 2 ? x1.size(0) : 1;
+  // Just like we do in the CPU code, assume that result is always batched
+  int64_t batch = result.size(0);
   const int block_x = 64;
   const int block_y = 16;
   const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);
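
A sketch of the "result is always batched" contract assumed here: after the Distance.cpp change, the gradient is allocated as a 3-D tensor (linearized batch, n, m), so result.size(0) is well defined even for inputs with no batch dimensions (n and m naming x1's last two sizes follows the C++ code):

import torch

x1 = torch.randn(5, 3)                      # no batch dimensions
batch_product, n, m = 1, x1.size(-2), x1.size(-1)
result = torch.empty(batch_product, n, m)   # mirrors at::empty({batch_product, n, m}, ...)
print(result.size(0))                       # 1 -- what the kernel reads as `batch`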
@@ -352,7 +353,7 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
   //we call grad.contiguous() before backward, so stride is guaranteed to be 1
   const int64_t gs = 1;
 
-  Tensor buffer = (x1.dim() > 2) ? at::empty({batch, r2, r1, m}, result.options()) : at::empty({r2, r1, m}, result.options());
+  Tensor buffer = at::empty({batch, r2, r1, m}, result.options());
   AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_cuda_backward", [&] {
     if (p == 1.0) {
       cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(buffer.data_ptr<scalar_t>(),
@@ -382,11 +383,7 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
     }
   });
 
-  if (x1.dim() > 2) {
-    at::sum_out(result, buffer, 1);
-  } else {
-    at::sum_out(result, buffer, 0);
-  }
+  at::sum_out(result, buffer, 1);
 
 }
 
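To illustrate the buffer layout and the dim-1 reduction, a hand-rolled p=2 backward in Python, checked against autograd (this assumes buffer[b, j, i, :] holds the contribution of x2[b, j] to the gradient of x1[b, i]; it is a sketch, not the kernel itself):

import torch

batch, r1, r2, m = 2, 4, 5, 3
x1 = torch.randn(batch, r1, m, dtype=torch.double, requires_grad=True)
x2 = torch.randn(batch, r2, m, dtype=torch.double)
dist = torch.cdist(x1, x2, p=2)            # (batch, r1, r2)
grad = torch.randn_like(dist)

with torch.no_grad():
    diff = x1.unsqueeze(2) - x2.unsqueeze(1)                                     # (batch, r1, r2, m)
    buffer = (grad / dist).transpose(1, 2).unsqueeze(-1) * diff.transpose(1, 2)  # (batch, r2, r1, m)
    grad_x1 = buffer.sum(dim=1)                                                  # like at::sum_out(result, buffer, 1)

dist.backward(grad)
print(torch.allclose(grad_x1, x1.grad))    # True (up to floating-point error)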

torch/testing/_internal/common_methods_invocations.py

Lines changed: 5 additions & 2 deletions
@@ -932,6 +932,7 @@ def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs):
                              args=(shape,)) for size, shape in test_cases)
 
 def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs):
+    small_S = 2
     test_cases = (
         ((S, S, 2), (S, S + 1, 2)),
         ((S, S), (S, S)),
@@ -942,8 +943,10 @@ def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs):
         ((1, 1), (S, 1)),
         # TODO enable that as this causes "Floating point exception (core dumped)"
         # ((0, 5), (4, 5)),
-        # TODO enable that as this causes https://github.com/pytorch/pytorch/issues/55370
-        # ((S, S, 21, 2), (S, S, 22, 2))
+        # Using S here would make this one test take 9s
+        ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)),
+        ((small_S, 1, 1, small_S), (1, small_S, small_S)),
+        ((1, 1, small_S), (small_S, 1, small_S, small_S)),
     )
 
     samples = []
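
A hypothetical spot-check of one of the new broadcasting shapes (small_S = 2), along the lines of what the OpInfo-driven autograd tests exercise; this snippet is not part of the PR:

import torch
from torch.autograd import gradcheck

small_S = 2
x1 = torch.randn(small_S, 1, 1, small_S, dtype=torch.double, requires_grad=True)
x2 = torch.randn(1, small_S, small_S, dtype=torch.double, requires_grad=True)
print(gradcheck(lambda a, b: torch.cdist(a, b, p=2), (x1, x2)))  # True once the broadcasted backward is correct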
