Skip to content

Commit d5e0919

Browse files
brad-mengchi authored and facebook-github-bot committed
jagged_dense_elementwise_mul add Meta and Autograd backend
Summary: Add Meta and Autograd backends for jagged_dense_elementwise_mul for dynamo and AOT autograd. This is a more proper way of implementing Autograd so that inductor works for jagged ops. Also added Meta tensors so we're not at the mercy of arbitrary zero inputs fed in by default. Differential Revision: D41383597 fbshipit-source-id: 6b3c883980e036a9b130ee588c6186c0ca072680
1 parent 053ab70 commit d5e0919

File tree

5 files changed

+282
-66
lines changed

5 files changed

+282
-66
lines changed

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,17 @@ at::Tensor jagged_dense_dense_elementwise_add_jagged_output_forward(
427427
const at::Tensor& y_0,
428428
const at::Tensor& y_1);
429429

430+
at::Tensor jagged_dense_elementwise_mul_forward(
431+
const at::Tensor& x_values,
432+
const std::vector<at::Tensor>& x_offsets,
433+
const at::Tensor& y);
434+
435+
std::tuple<at::Tensor, at::Tensor> jagged_dense_elementwise_mul_backward(
436+
const at::Tensor& grad_output,
437+
const std::vector<at::Tensor>& x_offsets,
438+
const at::Tensor& y,
439+
const at::Tensor& x_values);
440+
430441
///@ingroup sparse-data-cuda
431442
at::Tensor jagged_2d_to_dense_gpu(
432443
at::Tensor values,

fbgemm_gpu/src/jagged_tensor_ops.cu

Lines changed: 89 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,82 @@ void jagged_jagged_elementwise_dense_output_(
16351635
#undef INVOKE_KERNEL_WITH_DIM
16361636
}
16371637
1638+
Tensor jagged_dense_elementwise_mul_forward(
1639+
const Tensor& x_values,
1640+
const std::vector<Tensor>& x_offsets,
1641+
const Tensor& y) {
1642+
at::cuda::OptionalCUDAGuard device_guard;
1643+
device_guard.set_index(x_values.get_device());
1644+
1645+
Tensor output = at::empty_like(x_values);
1646+
1647+
AT_DISPATCH_SWITCH(
1648+
x_values.scalar_type(),
1649+
"jagged_dense_elementwise_mul_jagged_output_forward",
1650+
AT_DISPATCH_CASE(
1651+
at::ScalarType::Half,
1652+
[&] {
1653+
jagged_dense_elementwise_jagged_output_opt_<scalar_t>(
1654+
x_values,
1655+
x_offsets,
1656+
y,
1657+
output,
1658+
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1659+
return x * y;
1660+
});
1661+
} // lambda
1662+
) // CASE
1663+
AT_DISPATCH_CASE_FLOATING_TYPES([&] {
1664+
jagged_dense_elementwise_jagged_output_<scalar_t>(
1665+
x_values,
1666+
x_offsets,
1667+
y,
1668+
output,
1669+
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1670+
return x * y;
1671+
});
1672+
} // lambda
1673+
) // CASE_FLOATING_TYPES_AND
1674+
); // SWITCH
1675+
1676+
return output;
1677+
}
1678+
1679+
std::tuple<Tensor, Tensor> jagged_dense_elementwise_mul_backward(
1680+
const Tensor& grad_output,
1681+
const std::vector<Tensor>& x_offsets,
1682+
const Tensor& y,
1683+
const Tensor& x_values) {
1684+
at::cuda::OptionalCUDAGuard device_guard;
1685+
device_guard.set_index(grad_output.get_device());
1686+
1687+
Tensor x_values_grad = at::empty_like(grad_output);
1688+
Tensor y_grad = at::empty_like(y);
1689+
1690+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
1691+
x_values.scalar_type(), "jagged_scalars", [&] {
1692+
jagged_dense_elementwise_jagged_output_<scalar_t>(
1693+
grad_output,
1694+
x_offsets,
1695+
y,
1696+
x_values_grad,
1697+
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1698+
return x * y;
1699+
});
1700+
1701+
jagged_jagged_elementwise_dense_output_<scalar_t>(
1702+
grad_output,
1703+
x_offsets,
1704+
x_values,
1705+
y_grad,
1706+
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1707+
return x * y;
1708+
});
1709+
});
1710+
1711+
return {x_values_grad, y_grad};
1712+
}
1713+
16381714
class JaggedDenseMulGPUOp
16391715
: public torch::autograd::Function<JaggedDenseMulGPUOp> {
16401716
public:
@@ -1650,39 +1726,7 @@ class JaggedDenseMulGPUOp
16501726
tensors_to_save.push_back(y);
16511727
ctx->save_for_backward(tensors_to_save);
16521728
1653-
at::cuda::OptionalCUDAGuard device_guard;
1654-
device_guard.set_index(x_values.get_device());
1655-
1656-
Tensor output = at::empty_like(x_values);
1657-
1658-
AT_DISPATCH_SWITCH(
1659-
x_values.scalar_type(),
1660-
"jagged_dense_elementwise_mul_jagged_output_forward",
1661-
AT_DISPATCH_CASE(
1662-
at::ScalarType::Half,
1663-
[&] {
1664-
jagged_dense_elementwise_jagged_output_opt_<scalar_t>(
1665-
x_values,
1666-
x_offsets,
1667-
y,
1668-
output,
1669-
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1670-
return x * y;
1671-
});
1672-
} // lambda
1673-
) // CASE
1674-
AT_DISPATCH_CASE_FLOATING_TYPES([&] {
1675-
jagged_dense_elementwise_jagged_output_<scalar_t>(
1676-
x_values,
1677-
x_offsets,
1678-
y,
1679-
output,
1680-
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1681-
return x * y;
1682-
});
1683-
} // lambda
1684-
) // CASE_FLOATING_TYPES_AND
1685-
); // SWITCH
1729+
auto output = jagged_dense_elementwise_mul_forward(x_values, x_offsets, y);
16861730
16871731
return {output};
16881732
}
@@ -1698,34 +1742,13 @@ class JaggedDenseMulGPUOp
16981742
Tensor y = ctx->get_saved_variables().back();
16991743
TORCH_CHECK(grad_outputs.size() == 1);
17001744
1701-
at::cuda::OptionalCUDAGuard device_guard;
1702-
device_guard.set_index(grad_outputs[0].get_device());
1703-
1704-
Tensor x_values_grad = at::empty_like(grad_outputs[0]);
1705-
Tensor y_grad = at::empty_like(y);
1706-
1707-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
1708-
x_values.scalar_type(), "jagged_scalars", [&] {
1709-
jagged_dense_elementwise_jagged_output_<scalar_t>(
1710-
grad_outputs[0],
1711-
x_offsets,
1712-
y,
1713-
x_values_grad,
1714-
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1715-
return x * y;
1716-
});
1717-
1718-
jagged_jagged_elementwise_dense_output_<scalar_t>(
1719-
grad_outputs[0],
1720-
x_offsets,
1721-
x_values,
1722-
y_grad,
1723-
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
1724-
return x * y;
1725-
});
1726-
});
1745+
auto outputs = jagged_dense_elementwise_mul_backward(
1746+
grad_outputs[0], x_offsets, y, x_values);
17271747
1728-
return {x_values_grad, y_grad, torch::autograd::Variable()};
1748+
return {
1749+
std::get<0>(outputs),
1750+
std::get<1>(outputs),
1751+
torch::autograd::Variable()};
17291752
}
17301753
};
17311754
@@ -3006,6 +3029,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
30063029
fbgemm_gpu::jagged_dense_dense_elementwise_add_jagged_output);
30073030
DISPATCH_TO_CUDA(
30083031
"jagged_dense_elementwise_mul", fbgemm_gpu::jagged_dense_elementwise_mul);
3032+
DISPATCH_TO_CUDA(
3033+
"jagged_dense_elementwise_mul_forward",
3034+
fbgemm_gpu::jagged_dense_elementwise_mul_forward);
3035+
DISPATCH_TO_CUDA(
3036+
"jagged_dense_elementwise_mul_backward",
3037+
fbgemm_gpu::jagged_dense_elementwise_mul_backward);
30093038
DISPATCH_TO_CUDA(
30103039
"batched_dense_vec_jagged_2d_mul",
30113040
fbgemm_gpu::batched_dense_vec_jagged_2d_mul);

fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ class JaggedToPaddedDenseAutogradOp
6565
}
6666
};
6767

68-
class JaggedDenseDenseAddJaggedOutputOp
69-
: public torch::autograd::Function<JaggedDenseDenseAddJaggedOutputOp> {
68+
class JaggedDenseDenseAddJaggedOutputAutogradOp
69+
: public torch::autograd::Function<
70+
JaggedDenseDenseAddJaggedOutputAutogradOp> {
7071
public:
7172
static torch::autograd::variable_list forward(
7273
torch::autograd::AutogradContext* ctx,
@@ -116,6 +117,56 @@ class JaggedDenseDenseAddJaggedOutputOp
116117
}
117118
};
118119

120+
class JaggedDenseMulAutogradOp
121+
: public torch::autograd::Function<JaggedDenseMulAutogradOp> {
122+
public:
123+
static torch::autograd::variable_list forward(
124+
torch::autograd::AutogradContext* ctx,
125+
const Tensor& x_values,
126+
const std::vector<Tensor>& x_offsets,
127+
const Tensor& y) {
128+
std::vector<Tensor> tensors_to_save;
129+
tensors_to_save.push_back(x_values);
130+
tensors_to_save.insert(
131+
tensors_to_save.end(), x_offsets.begin(), x_offsets.end());
132+
tensors_to_save.push_back(y);
133+
ctx->save_for_backward(tensors_to_save);
134+
135+
static auto op =
136+
c10::Dispatcher::singleton()
137+
.findSchemaOrThrow(
138+
"fbgemm::jagged_dense_elementwise_mul_forward", "")
139+
.typed<decltype(jagged_dense_elementwise_mul_forward)>();
140+
Tensor output = op.call(x_values, x_offsets, y);
141+
142+
return {output};
143+
}
144+
145+
static torch::autograd::variable_list backward(
146+
torch::autograd::AutogradContext* ctx,
147+
torch::autograd::variable_list grad_outputs) {
148+
const Tensor x_values = ctx->get_saved_variables().front();
149+
std::vector<Tensor> x_offsets;
150+
for (size_t i = 1; i < ctx->get_saved_variables().size() - 1; ++i) {
151+
x_offsets.push_back(ctx->get_saved_variables()[i]);
152+
}
153+
Tensor y = ctx->get_saved_variables().back();
154+
TORCH_CHECK(grad_outputs.size() == 1);
155+
156+
static auto op =
157+
c10::Dispatcher::singleton()
158+
.findSchemaOrThrow(
159+
"fbgemm::jagged_dense_elementwise_mul_backward", "")
160+
.typed<decltype(jagged_dense_elementwise_mul_backward)>();
161+
auto outputs = op.call(grad_outputs[0], x_offsets, y, x_values);
162+
163+
return {
164+
std::get<0>(outputs),
165+
torch::autograd::Variable(),
166+
std::get<1>(outputs)};
167+
}
168+
};
169+
119170
///@ingroup jagged-tensor-ops-autograd
120171
Tensor jagged_to_padded_dense_autograd(
121172
const Tensor& values,
@@ -158,12 +209,22 @@ jagged_dense_dense_elementwise_add_jagged_output_autograd(
158209
const std::vector<Tensor>& x_offsets,
159210
const Tensor& y_0,
160211
const Tensor& y_1) {
161-
auto sum_values = JaggedDenseDenseAddJaggedOutputOp::apply(
212+
auto sum_values = JaggedDenseDenseAddJaggedOutputAutogradOp::apply(
162213
x_values, x_offsets, y_0, y_1)[0];
163214

164215
return {sum_values, x_offsets};
165216
}
166217

218+
std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul_autograd(
219+
const Tensor& x_values,
220+
const std::vector<Tensor>& x_offsets,
221+
const Tensor& y) {
222+
// Convert to jagged
223+
auto prod_values = JaggedDenseMulAutogradOp::apply(x_values, x_offsets, y)[0];
224+
225+
return {prod_values, x_offsets};
226+
}
227+
167228
} // namespace fbgemm_gpu
168229

169230
TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
@@ -178,4 +239,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
178239
"jagged_dense_dense_elementwise_add_jagged_output",
179240
TORCH_FN(fbgemm_gpu::
180241
jagged_dense_dense_elementwise_add_jagged_output_autograd));
242+
m.impl(
243+
"jagged_dense_elementwise_mul",
244+
TORCH_FN(fbgemm_gpu::jagged_dense_elementwise_mul_autograd));
181245
}

0 commit comments

Comments (0)