Commit 78b5c10

[Sparse] Add sparse addmm kernel (dense+coo*dense->dense,dense+csr*dense->dense) (#44451)
1 parent a0bccd9 commit 78b5c10

12 files changed: +725 -2 lines

paddle/phi/api/yaml/sparse_api.yaml

Lines changed: 11 additions & 0 deletions

@@ -266,6 +266,17 @@
     layout : x
   backward : values_grad
 
+- api: addmm
+  args : (Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0)
+  output : Tensor(out)
+  kernel :
+    func : addmm_csr_dense {dense, sparse_csr, dense -> dense},
+           addmm_csr_csr {sparse_csr, sparse_csr, sparse_csr -> sparse_csr},
+           addmm_coo_dense {dense, sparse_coo, dense -> dense},
+           addmm_coo_coo {sparse_coo, sparse_coo, sparse_coo -> sparse_coo}
+    layout : x
+  backward: addmm_grad
+
 - api: coalesce
   args : (Tensor x)
   output : Tensor(out)
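
Following the conventional addmm definition, the new entry computes out = beta * input + alpha * (x @ y), and the kernel list above dispatches on the tensor types of input, x and y. As a worked example of the element-wise formula, with input of shape [m, n], x of shape [m, k] and y of shape [k, n]:

    out[i][j] = beta * input[i][j] + alpha * sum over p of (x[i][p] * y[p][j])

Per the commit title, only the "dense + coo @ dense -> dense" and "dense + csr @ dense -> dense" paths gain an implementation in this change; the all-sparse variants are declared but left as TODOs in the kernel headers below.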

paddle/phi/api/yaml/sparse_bw_api.yaml

Lines changed: 10 additions & 0 deletions

@@ -30,6 +30,16 @@
     func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
           add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
 
+- backward_api : addmm_grad
+  forward : addmm(Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) -> Tensor(out)
+  args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha=1.0, float beta=1.0)
+  output : Tensor(input_grad), Tensor(x_grad), Tensor(y_grad)
+  kernel :
+    func : addmm_csr_dense_grad {dense, sparse_csr, dense, dense -> dense, sparse_csr, dense},
+           addmm_csr_csr_grad {sparse_csr, sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr, sparse_csr},
+           addmm_coo_dense_grad {dense, sparse_coo, dense, dense -> dense, sparse_coo, dense},
+           addmm_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo, sparse_coo}
+
 - backward_api : asin_grad
   forward : asin(Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)

paddle/phi/kernels/sparse/addmm_grad_kernel.h

Lines changed: 77 additions & 0 deletions (new file)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"

namespace phi {
namespace sparse {

// TODO(zhouwei25): implement Backward of " COO + COO @ COO -> COO"
template <typename T, typename Context>
void AddmmCooCooGradKernel(const Context& dev_ctx,
                           const SparseCooTensor& input,
                           const SparseCooTensor& x,
                           const SparseCooTensor& y,
                           const SparseCooTensor& dout,
                           float alpha,
                           float beta,
                           SparseCooTensor* dinput,
                           SparseCooTensor* dx,
                           SparseCooTensor* dy);

// Backward of "DENSE + COO @ DENSE -> DENSE"
template <typename T, typename Context>
void AddmmCooDenseGradKernel(const Context& dev_ctx,
                             const DenseTensor& input,
                             const SparseCooTensor& x,
                             const DenseTensor& y,
                             const DenseTensor& dout,
                             float alpha,
                             float beta,
                             DenseTensor* dinput,
                             SparseCooTensor* dx,
                             DenseTensor* dy);

// TODO(zhouwei25): implement Backward of " CSR + CSR @ CSR -> CSR"
template <typename T, typename Context>
void AddmmCsrCsrGradKernel(const Context& dev_ctx,
                           const SparseCsrTensor& input,
                           const SparseCsrTensor& x,
                           const SparseCsrTensor& y,
                           const SparseCsrTensor& dout,
                           float alpha,
                           float beta,
                           SparseCsrTensor* dinput,
                           SparseCsrTensor* dx,
                           SparseCsrTensor* dy);

/* Backward of "DENSE + CSR @ DENSE -> DENSE" */
template <typename T, typename Context>
void AddmmCsrDenseGradKernel(const Context& dev_ctx,
                             const DenseTensor& input,
                             const SparseCsrTensor& x,
                             const DenseTensor& y,
                             const DenseTensor& dout,
                             float alpha,
                             float beta,
                             DenseTensor* dinput,
                             SparseCsrTensor* dx,
                             DenseTensor* dy);

}  // namespace sparse
}  // namespace phi
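
For reference, the gradients implied by out = beta * input + alpha * (x @ y) are d(input) = beta * d(out), d(x) = alpha * d(out) @ y^T (restricted to the sparsity pattern of x when dx is a sparse tensor), and d(y) = alpha * x^T @ d(out). The snippet below is a minimal dense reference sketch of that math; its names and signature are illustrative assumptions, not the Paddle kernels declared above.

// Minimal dense reference for the addmm gradients; row-major buffers.
#include <cstdint>
#include <vector>

void AddmmDenseReferenceGrad(const std::vector<float>& x,     // m x k
                             const std::vector<float>& y,     // k x n
                             const std::vector<float>& dout,  // m x n
                             int64_t m, int64_t k, int64_t n,
                             float alpha, float beta,
                             std::vector<float>* dinput,      // m x n
                             std::vector<float>* dx,          // m x k
                             std::vector<float>* dy) {        // k x n
  dinput->assign(m * n, 0.f);
  dx->assign(m * k, 0.f);
  dy->assign(k * n, 0.f);
  // d(input) = beta * d(out)
  for (int64_t i = 0; i < m * n; ++i) {
    (*dinput)[i] = beta * dout[i];
  }
  // d(x) = alpha * d(out) @ y^T  and  d(y) = alpha * x^T @ d(out)
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t p = 0; p < k; ++p) {
      for (int64_t j = 0; j < n; ++j) {
        (*dx)[i * k + p] += alpha * dout[i * n + j] * y[p * n + j];
        (*dy)[p * n + j] += alpha * x[i * k + p] * dout[i * n + j];
      }
    }
  }
}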

paddle/phi/kernels/sparse/addmm_kernel.h

Lines changed: 65 additions & 0 deletions (new file)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"

namespace phi {
namespace sparse {

// TODO(zhouwei25): implement " COO + COO @ COO -> COO"
template <typename T, typename Context>
void AddmmCooCooKernel(const Context& dev_ctx,
                       const SparseCooTensor& input,
                       const SparseCooTensor& x,
                       const SparseCooTensor& y,
                       float alpha,
                       float beta,
                       SparseCooTensor* out);

/* DENSE + COO @ DENSE -> DENSE */
template <typename T, typename Context>
void AddmmCooDenseKernel(const Context& dev_ctx,
                         const DenseTensor& input,
                         const SparseCooTensor& x,
                         const DenseTensor& y,
                         float alpha,
                         float beta,
                         DenseTensor* out);

// TODO(zhouwei25): implement " CSR + CSR @ CSR -> CSR"
template <typename T, typename Context>
void AddmmCsrCsrKernel(const Context& dev_ctx,
                       const SparseCsrTensor& input,
                       const SparseCsrTensor& x,
                       const SparseCsrTensor& y,
                       float alpha,
                       float beta,
                       SparseCsrTensor* out);

/* DENSE + CSR @ DENSE -> DENSE */
template <typename T, typename Context>
void AddmmCsrDenseKernel(const Context& dev_ctx,
                         const DenseTensor& input,
                         const SparseCsrTensor& x,
                         const DenseTensor& y,
                         float alpha,
                         float beta,
                         DenseTensor* out);

}  // namespace sparse
}  // namespace phi
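
To make the "DENSE + CSR @ DENSE -> DENSE" contract of AddmmCsrDenseKernel concrete, here is a minimal plain-C++ sketch over raw CSR arrays (row offsets, column indices, non-zero values) and row-major dense buffers. It is an illustrative reference under assumed names, not the Paddle implementation (the actual device kernels are not shown here).

// out = beta * input + alpha * (x_csr @ y), where x_csr is an m x k CSR matrix
// and input, y, out are row-major dense matrices of shapes m x n, k x n, m x n.
#include <cstdint>
#include <vector>

void AddmmCsrDenseReference(const std::vector<float>& input,      // m x n
                            const std::vector<int64_t>& x_crows,  // m + 1
                            const std::vector<int64_t>& x_cols,   // nnz
                            const std::vector<float>& x_values,   // nnz
                            const std::vector<float>& y,          // k x n
                            int64_t m, int64_t n,
                            float alpha, float beta,
                            std::vector<float>* out) {            // m x n
  out->assign(m * n, 0.f);
  for (int64_t i = 0; i < m; ++i) {
    // Start row i from beta * input.
    for (int64_t j = 0; j < n; ++j) {
      (*out)[i * n + j] = beta * input[i * n + j];
    }
    // Accumulate alpha * x[i, p] * y[p, :] over the non-zeros of row i.
    for (int64_t idx = x_crows[i]; idx < x_crows[i + 1]; ++idx) {
      const int64_t p = x_cols[idx];
      const float v = alpha * x_values[idx];
      for (int64_t j = 0; j < n; ++j) {
        (*out)[i * n + j] += v * y[p * n + j];
      }
    }
  }
}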

paddle/phi/kernels/sparse/cpu/addmm_grad_kernel.cc

Lines changed: 72 additions & 0 deletions (new file)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/addmm_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void AddmmCooDenseGradKernel(const Context& dev_ctx,
                             const DenseTensor& input,
                             const SparseCooTensor& x,
                             const DenseTensor& y,
                             const DenseTensor& dout,
                             float alpha,
                             float beta,
                             DenseTensor* dinput,
                             SparseCooTensor* dx,
                             DenseTensor* dy) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not support CPU backward kernel of 'sparse.addmm' now."));
}

template <typename T, typename Context>
void AddmmCsrDenseGradKernel(const Context& dev_ctx,
                             const DenseTensor& input,
                             const SparseCsrTensor& x,
                             const DenseTensor& y,
                             const DenseTensor& dout,
                             float alpha,
                             float beta,
                             DenseTensor* dinput,
                             SparseCsrTensor* dx,
                             DenseTensor* dy) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not support CPU backward kernel of 'sparse.addmm' now."));
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(addmm_coo_dense_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::AddmmCooDenseGradKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(addmm_csr_dense_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::AddmmCsrDenseGradKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

paddle/phi/kernels/sparse/cpu/addmm_kernel.cc

Lines changed: 67 additions & 0 deletions (new file)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/addmm_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace sparse {

/* DENSE + COO @ DENSE -> DENSE */
template <typename T, typename Context>
void AddmmCooDenseKernel(const Context& dev_ctx,
                         const DenseTensor& input,
                         const SparseCooTensor& x,
                         const DenseTensor& y,
                         float alpha,
                         float beta,
                         DenseTensor* out) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not support CPU kernel of 'sparse.addmm' now."));
}

/* DENSE + CSR @ DENSE -> DENSE */
template <typename T, typename Context>
void AddmmCsrDenseKernel(const Context& dev_ctx,
                         const DenseTensor& input,
                         const SparseCsrTensor& x,
                         const DenseTensor& y,
                         float alpha,
                         float beta,
                         DenseTensor* out) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "Not support CPU kernel of 'sparse.addmm' now."));
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(addmm_coo_dense,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::AddmmCooDenseKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(addmm_csr_dense,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::AddmmCsrDenseKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc

Lines changed: 2 additions & 2 deletions

@@ -29,7 +29,7 @@ void MatmulCsrDenseGradKernel(const Context& dev_ctx,
                               SparseCsrTensor* dx,
                               DenseTensor* dy) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      "Not support CPU backward kernel of Sparse Matmul now."));
+      "Not support CPU backward kernel of 'sparse.matmul' now."));
 }
 
 // TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR"
@@ -41,7 +41,7 @@ void MaskedMatmulCsrGradKernel(const Context& dev_ctx,
                               DenseTensor* dx,
                               DenseTensor* dy) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      "Not support CPU backward kernel of Matmul Mask As Sparse now."));
+      "Not support CPU backward kernel of 'sparse.masked_matmul' now."));
 }
 
 }  // namespace sparse
