@@ -145,7 +145,31 @@ Tensor _cdist_forward(const Tensor& x1, const Tensor& x2, const double p, c10::o
   return result;
 }

-Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& cdist) {
+Tensor _cdist_backward(const Tensor& grad, const Tensor& _x1, const Tensor& _x2, const double p, const Tensor& cdist) {
+  // Broadcasting might generate non-contiguous Tensors, so handle it before doing checks
+  int64_t c1 = _x1.size(-1);
+  int64_t c2 = _x2.size(-1);
+  int64_t r1 = _x1.size(-2);
+  int64_t r2 = _x2.size(-2);
+  auto dim1 = _x1.dim();
+  auto dim2 = _x2.dim();
+  IntArrayRef batch_tensor1(_x1.sizes().data(), dim1 - 2);
+  IntArrayRef batch_tensor2(_x2.sizes().data(), dim2 - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+  std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+  tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
+  std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+  tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
+
+  Tensor x1 = _x1;
+  if (tensor1_expand_size != x1.sizes()) {
+    x1 = x1.expand(tensor1_expand_size).contiguous();
+  }
+  Tensor x2 = _x2;
+  if (tensor2_expand_size != x2.sizes()) {
+    x2 = x2.expand(tensor2_expand_size).contiguous();
+  }
+
   TORCH_CHECK(x1.is_contiguous(), "_cdist_backward requires X1 to be contiguous");
   TORCH_CHECK(x2.is_contiguous(), "_cdist_backward requires X2 to be contiguous");
   TORCH_CHECK(cdist.is_contiguous(), "_cdist_backward requires dist to be contiguous");
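The core of the hunk above is `infer_size`, which applies the usual broadcasting rules to the batch portions of the two inputs; each input is then expanded to the broadcast batch shape (plus its own trailing row/column dimensions) and made contiguous. A minimal standalone sketch of what `infer_size` computes, with made-up shapes (not part of the patch):

```cpp
#include <ATen/ExpandUtils.h>
#include <iostream>
#include <vector>

int main() {
  // Batch portions of two hypothetical cdist inputs:
  //   x1 has shape (2, 1, 5, 4) -> batch portion {2, 1}
  //   x2 has shape (1, 3, 6, 4) -> batch portion {1, 3}
  std::vector<int64_t> batch1 = {2, 1};
  std::vector<int64_t> batch2 = {1, 3};

  // infer_size applies broadcasting rules: {2, 1} vs. {1, 3} -> {2, 3}
  std::vector<int64_t> expand_batch_portion = at::infer_size(batch1, batch2);

  for (int64_t d : expand_batch_portion) {
    std::cout << d << " ";  // prints "2 3"
  }
  std::cout << "\n";
  return 0;
}
```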
@@ -156,13 +180,17 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
   TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X1 got: ", device1);
   auto device2 = x2.device().type();
   TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);
-  IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(x1.dim() - 2, 0));
-  const int64_t batch_product = c10::multiply_integers(batch_tensor1);
+
+  // Compute the linearized batch size
+  const int64_t batch_product = c10::multiply_integers(expand_batch_portion);
+
   Tensor grad_x1 =
-      at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT)
-          .view({batch_product, n, m});
+      at::empty({batch_product, n, m}, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist);
-  return grad_x1;
+
+  // Use x1.sizes() here, not the original _x1.sizes(), as this gradient does not take broadcasting into account.
+  // The reduction over broadcast batch dimensions is handled automatically by the autograd engine.
+  return grad_x1.view(x1.sizes());
 }

 Tensor _pdist_forward(const Tensor& self, const double p) {
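For context, a small LibTorch-style usage sketch of the broadcasting case that the backward pass above now handles; the shapes and the `main` wrapper are illustrative only and not part of this change:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Batch dimensions (2, 1) and (1, 3) broadcast to (2, 3); the backward pass
  // expands the inputs (and makes them contiguous) before running its checks.
  auto x1 = torch::randn({2, 1, 5, 4}, torch::requires_grad());
  auto x2 = torch::randn({1, 3, 6, 4}, torch::requires_grad());

  auto dist = torch::cdist(x1, x2, /*p=*/2.0);  // shape (2, 3, 5, 6)
  dist.sum().backward();

  // Gradients come back in the original (un-broadcast) shapes; the reduction
  // over the broadcast batch dimensions is done by the autograd engine.
  std::cout << x1.grad().sizes() << "\n";  // [2, 1, 5, 4]
  std::cout << x2.grad().sizes() << "\n";  // [1, 3, 6, 4]
  return 0;
}
```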