Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion paddle/fluid/eager/grad_node_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,32 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
meta.SetDistAttr(dist_attr);
meta.SetDistTensorGlobalDims(dist_tensor->dims());
SetIsRunAutoParallel(true);
} else if (phi::SparseCsrTensor::classof(fwd_in.impl().get())) {
phi::SparseCsrTensor* sparse_tensor =
static_cast<phi::SparseCsrTensor*>(fwd_in.impl().get());
const phi::DenseTensor dense_tensor =
static_cast<const phi::DenseTensor>(sparse_tensor->values());
PADDLE_ENFORCE_NE(
dense_tensor.dtype(),
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor.meta());
meta.SetPlace(fwd_in.place());
} else if (phi::SparseCooTensor::classof(fwd_in.impl().get())) {
phi::SparseCooTensor* sparse_tensor =
static_cast<phi::SparseCooTensor*>(fwd_in.impl().get());
const phi::DenseTensor dense_tensor =
static_cast<const phi::DenseTensor>(sparse_tensor->values());
PADDLE_ENFORCE_NE(
dense_tensor.dtype(),
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor.meta());
meta.SetPlace(fwd_in.place());
} else {
VLOG(7)
<< "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
Expand Down Expand Up @@ -501,6 +527,32 @@ void GradNodeBase::SetGradOutMeta(const std::vector<paddle::Tensor>& fwd_in,
meta.SetTensorMeta(dense_tensor.meta());
meta.SetPlace(fwd_in_tensor.place());
}
} else if (phi::SparseCsrTensor::classof(fwd_in_tensor.impl().get())) {
phi::SparseCsrTensor* sparse_tensor =
static_cast<phi::SparseCsrTensor*>(fwd_in_tensor.impl().get());
const phi::DenseTensor dense_tensor =
static_cast<const phi::DenseTensor>(sparse_tensor->values());
PADDLE_ENFORCE_NE(
dense_tensor.dtype(),
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor.meta());
meta.SetPlace(fwd_in_tensor.place());
} else if (phi::SparseCooTensor::classof(fwd_in_tensor.impl().get())) {
phi::SparseCooTensor* sparse_tensor =
static_cast<phi::SparseCooTensor*>(fwd_in_tensor.impl().get());
const phi::DenseTensor dense_tensor =
static_cast<const phi::DenseTensor>(sparse_tensor->values());
PADDLE_ENFORCE_NE(
dense_tensor.dtype(),
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor.meta());
meta.SetPlace(fwd_in_tensor.place());
} else {
VLOG(7)
<< "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
Expand Down Expand Up @@ -668,7 +720,7 @@ void GradNodeBase::HandleComplexGradToRealGrad(
for (size_t rank_id = 0; rank_id < slot_out_grads.size(); rank_id++) {
if (bwd_out_meta_[slot_id].size() == 0) continue;
const GradSlotMeta& slot_meta = bwd_out_meta_[slot_id][rank_id];

VLOG(6) << "out_grad" << slot_out_grads[rank_id].dtype();
PADDLE_ENFORCE(
slot_meta.HasTensorMeta() > 0,
paddle::platform::errors::Fatal(
Expand Down
36 changes: 27 additions & 9 deletions paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,9 @@ PD_REGISTER_KERNEL(dense_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(csr_to_coo,
CPU,
Expand All @@ -342,7 +344,9 @@ PD_REGISTER_KERNEL(csr_to_coo,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(coo_to_csr,
CPU,
Expand All @@ -356,7 +360,9 @@ PD_REGISTER_KERNEL(coo_to_csr,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(dense_to_csr,
CPU,
Expand All @@ -369,7 +375,9 @@ PD_REGISTER_KERNEL(dense_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(coo_to_dense,
CPU,
Expand All @@ -383,7 +391,9 @@ PD_REGISTER_KERNEL(coo_to_dense,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(csr_to_dense,
CPU,
Expand All @@ -397,7 +407,9 @@ PD_REGISTER_KERNEL(csr_to_dense,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(values_coo,
CPU,
Expand All @@ -411,7 +423,9 @@ PD_REGISTER_KERNEL(values_coo,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

Expand Down Expand Up @@ -442,7 +456,9 @@ PD_REGISTER_KERNEL(values_csr,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

Expand All @@ -456,4 +472,6 @@ PD_REGISTER_KERNEL(sparse_coo_tensor,
uint8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
26 changes: 25 additions & 1 deletion paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,33 @@
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}

// Registers the CPU backward kernels of a sparse unary op for both sparse
// layouts, with a dtype list that extends the plain
// PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL macro by complex<float> and
// complex<double>.
//   name   - op name; expands to the kernel names name##_coo_grad and
//            name##_csr_grad.
//   prefix - kernel class prefix; expands to
//            phi::sparse::prefix##CooGradKernel / prefix##CsrGradKernel.
// Each registration body marks input 0 with the matching sparse layout
// (SPARSE_COO / SPARSE_CSR). Only /* */-style comments are safe inside the
// macro itself: line splicing happens before comment removal, so a //
// comment on a continued line would swallow the rest of the macro.
#define PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL_WITH_COMPLEX(name, prefix) \
PD_REGISTER_KERNEL(name##_coo_grad, \
CPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CooGradKernel, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
} \
\
PD_REGISTER_KERNEL(name##_csr_grad, \
CPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CsrGradKernel, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}

PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(sin, Sin)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(tan, Tan)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(asin, Asin)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(atan, Atan)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(sinh, Sinh)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(tanh, Tanh)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(asinh, Asinh)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(atanh, Atanh)
Expand All @@ -55,6 +77,8 @@ PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL_WITH_COMPLEX(sinh, Sinh)

PD_REGISTER_KERNEL(cast_coo_grad,
CPU,
ALL_LAYOUT,
Expand Down
26 changes: 25 additions & 1 deletion paddle/phi/kernels/sparse/cpu/unary_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,33 @@ void DivScalarCsrKernel(const Context& dev_ctx,
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}

// Registers the CPU forward kernels of a sparse unary op for both sparse
// layouts, with a dtype list that extends the plain
// PD_REGISTER_SPARSE_UNARY_CPU_KERNEL macro by complex<float> and
// complex<double>.
//   name   - op name; expands to the kernel names name##_coo and name##_csr.
//   prefix - kernel class prefix; expands to
//            phi::sparse::prefix##CooKernel / prefix##CsrKernel.
// Each registration body marks input 0 with the matching sparse layout
// (SPARSE_COO / SPARSE_CSR).
// NOTE(review): the macro name says "GRAD" but it registers the *forward*
// name##_coo / name##_csr kernels — looks like a copy-paste from the grad
// variant in unary_grad_kernel.cc. Consider renaming it to
// PD_REGISTER_SPARSE_UNARY_CPU_KERNEL_WITH_COMPLEX together with its
// call site below.
#define PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL_WITH_COMPLEX(name, prefix) \
PD_REGISTER_KERNEL(name##_coo, \
CPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CooKernel, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
} \
\
PD_REGISTER_KERNEL(name##_csr, \
CPU, \
ALL_LAYOUT, \
phi::sparse::prefix##CsrKernel, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) { \
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
}

PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(sin, Sin)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(tan, Tan)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(asin, Asin)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(atan, Atan)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(sinh, Sinh)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(tanh, Tanh)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(asinh, Asinh)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(atanh, Atanh)
Expand All @@ -97,6 +119,8 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL_WITH_COMPLEX(sinh, Sinh)

PD_REGISTER_KERNEL(divide_scalar_coo,
CPU,
ALL_LAYOUT,
Expand Down
16 changes: 12 additions & 4 deletions paddle/phi/kernels/sparse/empty_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ PD_REGISTER_KERNEL(empty_like_coo,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

Expand All @@ -78,7 +80,9 @@ PD_REGISTER_KERNEL(empty_like_csr,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

Expand All @@ -95,7 +99,9 @@ PD_REGISTER_KERNEL(empty_like_coo,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

Expand All @@ -111,7 +117,9 @@ PD_REGISTER_KERNEL(empty_like_csr,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
#endif
36 changes: 27 additions & 9 deletions paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,9 @@ PD_REGISTER_KERNEL(dense_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(csr_to_coo,
GPU,
Expand All @@ -603,7 +605,9 @@ PD_REGISTER_KERNEL(csr_to_coo,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(coo_to_csr,
GPU,
Expand All @@ -617,7 +621,9 @@ PD_REGISTER_KERNEL(coo_to_csr,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(dense_to_csr,
GPU,
Expand All @@ -630,7 +636,9 @@ PD_REGISTER_KERNEL(dense_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(coo_to_dense,
GPU,
Expand All @@ -644,7 +652,9 @@ PD_REGISTER_KERNEL(coo_to_dense,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(csr_to_dense,
GPU,
Expand All @@ -658,7 +668,9 @@ PD_REGISTER_KERNEL(csr_to_dense,
int16_t,
int,
int64_t,
bool) {}
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(values_coo,
GPU,
Expand All @@ -672,7 +684,9 @@ PD_REGISTER_KERNEL(values_coo,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

Expand All @@ -688,7 +702,9 @@ PD_REGISTER_KERNEL(values_csr,
int16_t,
int,
int64_t,
bool) {
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

Expand Down Expand Up @@ -717,4 +733,6 @@ PD_REGISTER_KERNEL(sparse_coo_tensor,
uint8_t,
int16_t,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
Loading