Skip to content

Commit

Permalink
fix index_select bug and modify the op test parameter (#197)
Browse files Browse the repository at this point in the history
Co-authored-by: Sheng Wang <sheng.wang@mthreads.com>
  • Loading branch information
Salv1a and Sheng Wang authored Sep 10, 2024
1 parent 9571933 commit 9337062
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/flag_gems/ops/index_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def index_select_kernel(
pid_y = tl.program_id(axis=1)
rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
rows_mask = rows_offsets < M
cols_offsets = pid_y + tl.arange(0, BLOCK_N)
cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)
cols_mask = cols_offsets < N

block_mask = rows_mask and cols_mask
Expand Down
2 changes: 1 addition & 1 deletion tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def test_accuracy_gather_out(out_shape, inp_shape, dim, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("shape", [(8192, 256 * i) for i in range(1, 10, 2)])
@pytest.mark.parametrize("dim", DIM_LIST)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_index_select(shape, dim, dtype):
Expand Down

0 comments on commit 9337062

Please sign in to comment.