Skip to content

Commit

Permalink
[AMP OP&Test]fix index_select bf16 test (#51652)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxn12138 authored Mar 15, 2023
1 parent 6407672 commit e561644
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
3 changes: 0 additions & 3 deletions paddle/phi/kernels/gpu/index_select_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index,
int64_t nums,
int64_t N,
int64_t stride,
int64_t size,
Expand Down Expand Up @@ -104,7 +103,6 @@ void IndexSelectGradKernel(const Context& ctx,
<<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
Expand All @@ -115,7 +113,6 @@ void IndexSelectGradKernel(const Context& ctx,
<<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
Expand Down
8 changes: 6 additions & 2 deletions python/paddle/fluid/tests/unittests/test_index_select_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

np.random.seed(1024)


class TestIndexSelectOp(OpTest):
def setUp(self):
Expand Down Expand Up @@ -119,7 +121,7 @@ def init_dtype_type(self):
self.dim = 1
self.x_type = np.uint16
self.index_type = np.int64
self.x_shape = (100, 4, 5)
self.x_shape = (20, 4, 5)
self.index_size = 100

def test_check_output(self):
Expand All @@ -137,10 +139,11 @@ def input_data(self):
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
]
)
).astype("float32")
self.data_index = np.array([0, 1, 1]).astype('int32')

def test_index_select_api(self):
paddle.enable_static()
self.input_data()

# case 1:
Expand Down Expand Up @@ -176,6 +179,7 @@ def test_index_select_api(self):
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

def test_dygraph_api(self):
paddle.disable_static()
self.input_data()
# case 1:
with fluid.dygraph.guard():
Expand Down

0 comments on commit e561644

Please sign in to comment.