From 9337062ced4954364e684c4489d96d2c4c87d44d Mon Sep 17 00:00:00 2001 From: Salv1a <772309295@qq.com> Date: Tue, 10 Sep 2024 09:39:06 +0800 Subject: [PATCH] fix index_select bug and modify the op test parameter (#197) Co-authored-by: Sheng Wang --- src/flag_gems/ops/index_select.py | 2 +- tests/test_reduction_ops.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/flag_gems/ops/index_select.py b/src/flag_gems/ops/index_select.py index aa342914..b78a764c 100644 --- a/src/flag_gems/ops/index_select.py +++ b/src/flag_gems/ops/index_select.py @@ -28,7 +28,7 @@ def index_select_kernel( pid_y = tl.program_id(axis=1) rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] rows_mask = rows_offsets < M - cols_offsets = pid_y + tl.arange(0, BLOCK_N) + cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N) cols_mask = cols_offsets < N block_mask = rows_mask and cols_mask diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 4787e155..421a7922 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -907,7 +907,7 @@ def test_accuracy_gather_out(out_shape, inp_shape, dim, dtype): gems_assert_equal(res_out, ref_out) -@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("shape", [(8192, 256 * i) for i in range(1, 10, 2)]) @pytest.mark.parametrize("dim", DIM_LIST) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_index_select(shape, dim, dtype):