[Operator] Add repeat_interleave_self_int op
zfu82 committed Sep 14, 2024
1 parent 8d65c90 commit 1e8f1aa
Showing 5 changed files with 102 additions and 0 deletions.
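
For context, `repeat_interleave.self_int` is the ATen overload in which `repeats` is a single Python int applied uniformly to every element, as opposed to the overload taking a per-element tensor of repeat counts. A quick reminder of the reference semantics the new kernel must reproduce (standard PyTorch, shown purely for illustration):

import torch

x = torch.tensor([[1, 2], [3, 4]])

torch.repeat_interleave(x, 2, dim=1)
# tensor([[1, 1, 2, 2],
#         [3, 3, 4, 4]])

torch.repeat_interleave(x, 2)  # dim=None flattens the input first
# tensor([1, 1, 2, 2, 3, 3, 4, 4])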
17 changes: 17 additions & 0 deletions benchmark/test_pointwise_perf.py
@@ -630,3 +630,20 @@ def repeat_arg(dtype, batch, size):
        sizes=SIZES,
    )
    bench.run()


def test_perf_repeat_interleave_self_int():
    def repeat_interleave_self_int_arg(dtype, batch, size):
        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
        repeats = 2
        return inp, repeats

    bench = Benchmark(
        op_name="repeat_interleave_self_int",
        torch_op=torch.repeat_interleave,
        arg_func=repeat_interleave_self_int_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()
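
Presumably the harness invokes `torch_op(*arg_func(dtype, batch, size))` for each configuration; that is an assumption about `Benchmark` internals, which are not part of this diff. For one illustrative configuration the measured call would expand to roughly:

# Assumed expansion of a single benchmark iteration (Benchmark internals
# are not shown in this diff):
inp, repeats = repeat_interleave_self_int_arg(torch.float32, 1024, 1024)
out = torch.repeat_interleave(inp, repeats)  # dim=None: flattened, 1024 * 1024 * 2 elements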
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -150,6 +150,7 @@ def enable(lib=aten_lib):
lib.impl("masked_select", masked_select, "CUDA")
lib.impl("stack", stack, "CUDA")
lib.impl("hstack", hstack, "CUDA")
lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")


class use_gems:
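With this registration in place, entering the `use_gems` context routes the `repeat_interleave.self_int` overload to the Triton-backed implementation. A minimal usage sketch (assuming a CUDA device, mirroring how the accuracy test below invokes it):

import torch
import flag_gems

x = torch.randn(8, 16, device="cuda")
with flag_gems.use_gems():
    # torch.repeat_interleave with an int `repeats` dispatches to
    # repeat_interleave.self_int, now backed by the Triton kernel.
    y = torch.repeat_interleave(x, 2, dim=0)
assert y.shape == (16, 16)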
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -78,6 +78,7 @@
from .reciprocal import reciprocal
from .relu import relu
from .repeat import repeat
from .repeat_interleave import repeat_interleave_self_int
from .resolve_conj import resolve_conj
from .resolve_neg import resolve_neg
from .rms_norm import rms_norm
@@ -233,4 +234,5 @@
"masked_select",
"stack",
"hstack",
"repeat_interleave_self_int",
]
64 changes: 64 additions & 0 deletions src/flag_gems/ops/repeat_interleave.py
@@ -0,0 +1,64 @@
import torch
import triton

from ..utils.pointwise_dynamic import pointwise_dynamic
from ..utils.shape_utils import c_contiguous_stride, volume
from ..utils.tensor_wrapper import StridedBuffer


@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x


def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    if dim is None:
        # Per torch semantics, dim=None flattens the input before repeating.
        nelems = volume(inp.shape)
        inp_shape = [nelems]
        inp_stride = [1]
        output_shape = [nelems]
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )
        inp_shape = list(inp.shape)
        inp_stride = list(inp.stride())
        output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

    # View the input with an extra size-`repeats` axis of stride 0 right after
    # `dim`, so each element is read `repeats` times, and give the output the
    # same shape with contiguous strides; an element-wise copy between the two
    # views then writes the interleaved result.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output
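
The core of the kernel is this zero-stride broadcast. For intuition, here is a rough pure-PyTorch analogue of the same trick (a sketch for illustration only; the helper name is ours, not part of this commit, and it assumes `dim` is already normalized to be non-negative):

import torch


def repeat_interleave_via_expand(inp, repeats, dim):
    # Insert a length-`repeats` axis right after `dim`; expand() realizes it
    # with stride 0, so every element is read `repeats` times without copying.
    expanded = inp.unsqueeze(dim + 1).expand(
        *inp.shape[: dim + 1], repeats, *inp.shape[dim + 1 :]
    )
    # Materializing that view in contiguous order interleaves the repeats.
    return expanded.reshape(
        *inp.shape[:dim], inp.shape[dim] * repeats, *inp.shape[dim + 1 :]
    )


x = torch.randn(4, 3)
assert torch.equal(
    repeat_interleave_via_expand(x, 2, dim=1),
    torch.repeat_interleave(x, 2, dim=1),
)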
18 changes: 18 additions & 0 deletions tests/test_binary_pointwise_ops.py
@@ -976,3 +976,21 @@ def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
    ref_out = torch.allclose(ref_inp1, ref_inp2, rtol, atol, equal_nan=equal_nan)

    assert res_out == ref_out


REPEAT_INTERLEAVE_REPEATS = [2]
REPEAT_INTERLEAVE_DIM = [-1, 0]


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_repeat_interleave_self_int(shape, dim, dtype):
inp = torch.randn(shape, dtype=dtype, device="cuda")
repeats = 2
ref_inp = to_reference(inp)

ref_out = torch.repeat_interleave(ref_inp, repeats, dim)
with flag_gems.use_gems():
res_out = torch.repeat_interleave(ref_inp, repeats, dim)
gems_assert_equal(res_out, ref_out)
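
The parametrization covers dim=-1 and dim=0, but not the dim=None (flatten) path that the operator also implements. A possible follow-up check in the same style (a sketch, not part of this commit):

inp = torch.randn((4, 4), dtype=torch.float32, device="cuda")
ref_out = torch.repeat_interleave(to_reference(inp), 2)  # dim=None flattens
with flag_gems.use_gems():
    res_out = torch.repeat_interleave(inp, 2)
gems_assert_equal(res_out, ref_out)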
