Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions backends/npu/kernels/squeeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
const std::vector<int>& axes,
phi::DenseTensor* out,
phi::DenseTensor* xshape) {
phi::DenseTensor* out) {
auto stream = dev_ctx.stream();

auto x_dims = x.dims();
Expand All @@ -97,6 +96,15 @@ void SqueezeKernel(const Context& dev_ctx,
out->Resize(out_dims);
}

// Squeeze kernel variant whose signature additionally carries an `xshape`
// output tensor.
//
// All work is delegated to custom_kernel::SqueezeKernel, which receives the
// same device context, input tensor, axes and output tensor. `xshape` is
// neither read nor written inside this function.
// NOTE(review): presumably xshape's metadata is populated outside this
// kernel -- confirm against the framework's shape-inference step.
template <typename T, typename Context>
void SqueezeWithXShapeKernel(const Context& dev_ctx,
                             const phi::DenseTensor& x,
                             const std::vector<int>& axes,
                             phi::DenseTensor* out,
                             phi::DenseTensor* xshape) {
  // Forward everything except xshape to the plain squeeze implementation.
  custom_kernel::SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
}

template <typename T, typename Context>
void SqueezeGradKernel(const Context& dev_ctx,
const phi::DenseTensor& xshape,
Expand Down Expand Up @@ -127,6 +135,19 @@ PD_REGISTER_PLUGIN_KERNEL(squeeze,
phi::dtype::float16,
double) {}

// Registration of custom_kernel::SqueezeWithXShapeKernel under the kernel
// name `squeeze_with_xshape`, with backend `ascend` and layout `ALL_LAYOUT`.
// The trailing arguments are the element types the kernel is registered for.
// The braces after the macro are left empty here.
// NOTE(review): the macro's expansion is not visible in this file view;
// the description above follows the argument names only.
PD_REGISTER_PLUGIN_KERNEL(squeeze_with_xshape,
ascend,
ALL_LAYOUT,
custom_kernel::SqueezeWithXShapeKernel,
bool,
int,
uint8_t,
int8_t,
int64_t,
float,
phi::dtype::float16,
double) {}

PD_REGISTER_PLUGIN_KERNEL(squeeze_grad,
ascend,
ALL_LAYOUT,
Expand Down
28 changes: 26 additions & 2 deletions backends/npu/kernels/unsqueeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ template <typename T, typename Context>
void UnsqueezeNPUKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
const phi::IntArray& axes,
phi::DenseTensor* out,
phi::DenseTensor* xshape) {
phi::DenseTensor* out) {
auto x_dims = x.dims();
auto out_dims = out->dims();

Expand All @@ -88,6 +87,15 @@ void UnsqueezeNPUKernel(const Context& dev_ctx,
out->Resize(out_dims); // copy will reset the dims.
}

// Unsqueeze kernel variant whose signature additionally carries an `xshape`
// output tensor.
//
// All work is delegated to custom_kernel::UnsqueezeNPUKernel, which receives
// the same device context, input tensor, axes and output tensor. `xshape` is
// neither read nor written inside this function.
// NOTE(review): presumably xshape's metadata is populated outside this
// kernel -- confirm against the framework's shape-inference step.
template <typename T, typename Context>
void UnsqueezeWithXShapeNPUKernel(const Context& dev_ctx,
                                  const phi::DenseTensor& x,
                                  const phi::IntArray& axes,
                                  phi::DenseTensor* out,
                                  phi::DenseTensor* xshape) {
  // Forward everything except xshape to the plain unsqueeze implementation.
  custom_kernel::UnsqueezeNPUKernel<T, Context>(dev_ctx, x, axes, out);
}

template <typename T, typename Context>
void UnsqueezeGradNPUKernel(const Context& dev_ctx,
const phi::DenseTensor& x_shape,
Expand Down Expand Up @@ -119,6 +127,22 @@ PD_REGISTER_PLUGIN_KERNEL(unsqueeze,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

// Registration of custom_kernel::UnsqueezeWithXShapeNPUKernel under the
// kernel name `unsqueeze_with_xshape`, with backend `ascend` and layout
// `ALL_LAYOUT`. The trailing arguments are the element types the kernel is
// registered for. The braces after the macro are left empty here.
// NOTE(review): the macro's expansion is not visible in this file view;
// the description above follows the argument names only.
PD_REGISTER_PLUGIN_KERNEL(unsqueeze_with_xshape,
ascend,
ALL_LAYOUT,
custom_kernel::UnsqueezeWithXShapeNPUKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_PLUGIN_KERNEL(unsqueeze_grad,
ascend,
ALL_LAYOUT,
Expand Down