@@ -1635,6 +1635,82 @@ void jagged_jagged_elementwise_dense_output_(
 #undef INVOKE_KERNEL_WITH_DIM
 }
 
+Tensor jagged_dense_elementwise_mul_forward(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y) {
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(x_values.get_device());
+
+  Tensor output = at::empty_like(x_values);
+
+  AT_DISPATCH_SWITCH(
+      x_values.scalar_type(),
+      "jagged_dense_elementwise_mul_jagged_output_forward",
+      AT_DISPATCH_CASE(
+          at::ScalarType::Half,
+          [&] {
+            jagged_dense_elementwise_jagged_output_opt_<scalar_t>(
+                x_values,
+                x_offsets,
+                y,
+                output,
+                [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                  return x * y;
+                });
+          } // lambda
+          ) // CASE
+      AT_DISPATCH_CASE_FLOATING_TYPES([&] {
+        jagged_dense_elementwise_jagged_output_<scalar_t>(
+            x_values,
+            x_offsets,
+            y,
+            output,
+            [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+              return x * y;
+            });
+      } // lambda
+      ) // CASE_FLOATING_TYPES_AND
+  ); // SWITCH
+
+  return output;
+}
+
+std::tuple<Tensor, Tensor> jagged_dense_elementwise_mul_backward(
+    const Tensor& grad_output,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y,
+    const Tensor& x_values) {
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(grad_output.get_device());
+
+  Tensor x_values_grad = at::empty_like(grad_output);
+  Tensor y_grad = at::empty_like(y);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      x_values.scalar_type(), "jagged_scalars", [&] {
+        jagged_dense_elementwise_jagged_output_<scalar_t>(
+            grad_output,
+            x_offsets,
+            y,
+            x_values_grad,
+            [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+              return x * y;
+            });
+
+        jagged_jagged_elementwise_dense_output_<scalar_t>(
+            grad_output,
+            x_offsets,
+            x_values,
+            y_grad,
+            [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+              return x * y;
+            });
+      });
+
+  return {x_values_grad, y_grad};
+}
+
 class JaggedDenseMulGPUOp
     : public torch::autograd::Function<JaggedDenseMulGPUOp> {
  public:
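The forward dispatch above routes at::Half through the optimized jagged_dense_elementwise_jagged_output_opt_ kernel, while float and double fall back to the generic jagged_dense_elementwise_jagged_output_ kernel; the backward reuses the same elementwise multiply for both gradients (dL/dx_values = grad * y as a jagged output, dL/dy = grad * x_values as a dense output). Below is a minimal sketch of calling the two split-out functions directly from C++; the wrapper name mul_and_grads and the tensor setup are illustrative assumptions, and only the two fbgemm_gpu function names and signatures come from the diff itself.

#include <ATen/ATen.h>
#include <tuple>
#include <vector>

std::tuple<at::Tensor, at::Tensor, at::Tensor> mul_and_grads(
    const at::Tensor& x_values, // jagged values, on a CUDA device
    const std::vector<at::Tensor>& x_offsets, // one offsets tensor per jagged dim
    const at::Tensor& y) { // dense operand, shape-compatible with the jagged layout
  // Forward: elementwise x * y, written into a jagged tensor shaped like x_values.
  at::Tensor out = fbgemm_gpu::jagged_dense_elementwise_mul_forward(
      x_values, x_offsets, y);
  // Backward with an all-ones upstream gradient:
  // grads.first  = dL/dx_values (jagged), grads.second = dL/dy (dense).
  auto grads = fbgemm_gpu::jagged_dense_elementwise_mul_backward(
      at::ones_like(out), x_offsets, y, x_values);
  return {out, std::get<0>(grads), std::get<1>(grads)};
}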
@@ -1650,39 +1726,7 @@ class JaggedDenseMulGPUOp
     tensors_to_save.push_back(y);
     ctx->save_for_backward(tensors_to_save);
 
-    at::cuda::OptionalCUDAGuard device_guard;
-    device_guard.set_index(x_values.get_device());
-
-    Tensor output = at::empty_like(x_values);
-
-    AT_DISPATCH_SWITCH(
-        x_values.scalar_type(),
-        "jagged_dense_elementwise_mul_jagged_output_forward",
-        AT_DISPATCH_CASE(
-            at::ScalarType::Half,
-            [&] {
-              jagged_dense_elementwise_jagged_output_opt_<scalar_t>(
-                  x_values,
-                  x_offsets,
-                  y,
-                  output,
-                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                    return x * y;
-                  });
-            } // lambda
-            ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              x_offsets,
-              y,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x * y;
-              });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
-    ); // SWITCH
+    auto output = jagged_dense_elementwise_mul_forward(x_values, x_offsets, y);
 
     return {output};
   }
@@ -1698,34 +1742,13 @@ class JaggedDenseMulGPUOp
     Tensor y = ctx->get_saved_variables().back();
     TORCH_CHECK(grad_outputs.size() == 1);
 
-    at::cuda::OptionalCUDAGuard device_guard;
-    device_guard.set_index(grad_outputs[0].get_device());
-
-    Tensor x_values_grad = at::empty_like(grad_outputs[0]);
-    Tensor y_grad = at::empty_like(y);
-
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        x_values.scalar_type(), "jagged_scalars", [&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              grad_outputs[0],
-              x_offsets,
-              y,
-              x_values_grad,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x * y;
-              });
-
-          jagged_jagged_elementwise_dense_output_<scalar_t>(
-              grad_outputs[0],
-              x_offsets,
-              x_values,
-              y_grad,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x * y;
-              });
-        });
+    auto outputs = jagged_dense_elementwise_mul_backward(
+        grad_outputs[0], x_offsets, y, x_values);
 
-    return {x_values_grad, y_grad, torch::autograd::Variable()};
+    return {
+        std::get<0>(outputs),
+        std::get<1>(outputs),
+        torch::autograd::Variable()};
   }
 };
 
@@ -3006,6 +3029,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
       fbgemm_gpu::jagged_dense_dense_elementwise_add_jagged_output);
   DISPATCH_TO_CUDA(
       "jagged_dense_elementwise_mul", fbgemm_gpu::jagged_dense_elementwise_mul);
+  DISPATCH_TO_CUDA(
+      "jagged_dense_elementwise_mul_forward",
+      fbgemm_gpu::jagged_dense_elementwise_mul_forward);
+  DISPATCH_TO_CUDA(
+      "jagged_dense_elementwise_mul_backward",
+      fbgemm_gpu::jagged_dense_elementwise_mul_backward);
   DISPATCH_TO_CUDA(
       "batched_dense_vec_jagged_2d_mul",
       fbgemm_gpu::batched_dense_vec_jagged_2d_mul);
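With these CUDA registrations in place, and assuming the matching operator schemas are m.def'ed elsewhere in fbgemm_gpu (the definitions are not shown in this hunk), the split-out forward and backward become reachable through the PyTorch dispatcher. Below is a hedged sketch of a dispatcher-side caller; the schema string and the (Tensor, Tensor[], Tensor) -> Tensor signature are assumptions inferred from the C++ function signatures above.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor call_mul_forward(
    const at::Tensor& x_values,
    at::TensorList x_offsets,
    const at::Tensor& y) {
  // Resolve the registered op once; routes to the CUDA kernel registered
  // above when the inputs live on a CUDA device.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::jagged_dense_elementwise_mul_forward", "")
          .typed<at::Tensor(
              const at::Tensor&, at::TensorList, const at::Tensor&)>();
  return op.call(x_values, x_offsets, y);
}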