fix swiglu matmul 0 size tensor bug (#71442)
* fix swiglu matmul 0 size tensor bug

* fix bug

* disable xpu 0 size test

* fix bug
phlrain authored Mar 7, 2025
1 parent feaf1c0 commit 6673bae
Showing 5 changed files with 83 additions and 0 deletions.
7 changes: 7 additions & 0 deletions paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
@@ -229,6 +229,13 @@ void MatmulGradKernel(const Context& dev_ctx,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(dx);
phi::FullKernel<T>(
dev_ctx, common::vectorize(y.dims()), 0.0, y.dtype(), dy);

return;
}
// get dims
std::vector<std::int64_t> x_dims = common::vectorize(x.dims());
std::vector<std::int64_t> y_dims = common::vectorize(y.dims());
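For context, a minimal NumPy sketch (illustrative only, not the Paddle kernel) of why this early return is sound: with a zero-row x, dx = dz @ yᵀ is itself empty and only needs allocation, while dy = xᵀ @ dz is well defined and identically zero with y's shape, which is what FullKernel materializes above.

import numpy as np

x = np.ones((0, 128), dtype=np.float32)   # zero-size input
y = np.ones((128, 64), dtype=np.float32)
dz = np.ones((0, 64), dtype=np.float32)   # upstream gradient, same shape as x @ y

dx = dz @ y.T   # shape (0, 128): empty, allocation is all that is needed
dy = x.T @ dz   # shape (128, 64): well defined and all zeros

assert dx.shape == x.shape
assert dy.shape == y.shape
assert not dy.any()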
9 changes: 9 additions & 0 deletions paddle/phi/kernels/swiglu_grad_kernel.h
@@ -36,6 +36,15 @@ void SwiGLUGradKernel(const Context &ctx,
const DenseTensor &dz,
DenseTensor *dx,
DenseTensor *dy) {
if (x.numel() == 0) {
if (dx) {
ctx.template Alloc<T>(dx);
}
if (dy) {
ctx.template Alloc<T>(dy);
}
return;
}
const auto *x_ptr = x.data<T>();
const auto *dz_ptr = dz.data<T>();
auto *dx_ptr = dx ? ctx.template Alloc<T>(dx) : nullptr;
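A hedged reference for the backward path (a NumPy sketch, assuming SwiGLU here means z = silu(x) * y; not the Paddle kernel itself): dx = dz * y * silu'(x) and dy = dz * silu(x), so with a zero-size x both gradients are empty and the early return above only has to allocate them.

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def swiglu_grad_ref(x, y, dz):
    # z = silu(x) * y with silu(x) = x * sigmoid(x)
    s = sigmoid(x)
    dsilu = s * (1.0 + x * (1.0 - s))  # d/dx [x * sigmoid(x)]
    return dz * y * dsilu, dz * (x * s)

x = np.ones((0, 128), dtype=np.float32)
y = np.ones((0, 128), dtype=np.float32)
dz = np.ones((0, 128), dtype=np.float32)
dx, dy = swiglu_grad_ref(x, y, dz)
assert dx.size == 0 and dy.size == 0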
4 changes: 4 additions & 0 deletions paddle/phi/kernels/swiglu_kernel.h
@@ -28,6 +28,10 @@ void SwiGLUKernel(const Context &ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &y,
DenseTensor *z) {
if (x.numel() == 0) {
ctx.template Alloc<T>(z);
return;
}
const auto *x_ptr = x.data<T>();
auto *z_ptr = ctx.template Alloc<T>(z);
const auto &dims = x.dims();
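The forward kernel follows the same reasoning. A minimal sketch, assuming the usual SwiGLU definition swiglu(x, y) = silu(x) * y: every intermediate inherits x's zero-size shape, so allocating z is all the kernel needs to do before returning.

import numpy as np

def swiglu_ref(x, y):
    # swiglu(x, y) = silu(x) * y, with silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return (x / (1.0 + np.exp(-x))) * y

x = np.ones((0, 128), dtype=np.float32)
y = np.ones((0, 128), dtype=np.float32)
z = swiglu_ref(x, y)
assert z.shape == (0, 128) and z.size == 0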
42 changes: 42 additions & 0 deletions test/legacy_test/test_matmul_0_size_op.py
@@ -0,0 +1,42 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle import _C_ops
from paddle.base import core


@unittest.skipIf(
not core.is_compiled_with_cuda(), "mamtul 0 size only with in cuda"
)
class TestMatmulDygraph(unittest.TestCase):
def test_matmul(self):
x = paddle.ones([0, 128], dtype="float32")
y = paddle.ones([128, 128], dtype="float32")
x.stop_gradient = False
y.stop_gradient = False
out = paddle.matmul(x, y)

dz = paddle.ones([0, 128], dtype="float32")

out = _C_ops.matmul_grad(x, y, dz, False, False)

self.assertEqual(out[0].shape, x.shape)
self.assertEqual(out[1].shape, y.shape)


if __name__ == "__main__":
unittest.main()
21 changes: 21 additions & 0 deletions test/legacy_test/test_swiglu.py
@@ -20,6 +20,8 @@
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
from paddle import _C_ops
from paddle.base import core
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
@@ -276,5 +278,24 @@ def test_input_x_unshard_last_dim(self):
self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [0, -1])


@unittest.skipIf(
not core.is_compiled_with_cuda(), "mamtul 0 size only with in cuda"
)
class TestSwiglu0SizeDygraph(unittest.TestCase):
def test_swiglu(self):
x = paddle.ones([0, 128], dtype="float32")
y = paddle.ones([0, 128], dtype="float32")
x.stop_gradient = False
y.stop_gradient = False
out = fused_swiglu_impl(x, y)

dz = paddle.ones([0, 128], dtype="float32")

out = _C_ops.swiglu_grad(x, y, dz)

self.assertEqual(out[0].shape, x.shape)
self.assertEqual(out[1].shape, y.shape)


if __name__ == "__main__":
unittest.main()
