[Operator] Add cat op (#213)
zfu82 authored and DuanYaQi committed Sep 20, 2024
1 parent 5062384 commit 83ae8ca
Showing 5 changed files with 158 additions and 0 deletions.
46 changes: 46 additions & 0 deletions benchmark/test_special_perf.py
@@ -199,3 +199,49 @@ def hstack_args(dtype, batch, size):
sizes=SIZES,
)
bench.run()


def test_perf_cat():
def cat_args(dtype, batch, size):
inp1 = torch.randn([batch, size], dtype=dtype, device="cuda")
inp2 = torch.randn([batch, size], dtype=dtype, device="cuda")
return [[inp1, inp2]]

def cat_kwargs(dtype, batch, size):
return {"dim": 0}

bench = Benchmark(
op_name="cat",
torch_op=torch.cat,
arg_func=cat_args,
dtypes=FLOAT_DTYPES,
batch=POINTWISE_BATCH,
sizes=SIZES,
kwargs_func=cat_kwargs,
)
bench.run()


def test_perf_cat_int():
def cat_args(dtype, batch, size):
inp1 = torch.randint(
low=0, high=0x7FFF, size=[batch, size], dtype=dtype, device="cuda"
)
inp2 = torch.randint(
low=0, high=0x7FFF, size=[batch, size], dtype=dtype, device="cuda"
)
return [[inp1, inp2]]

def cat_kwargs(dtype, batch, size):
return {"dim": 0}

bench = Benchmark(
op_name="cat",
torch_op=torch.cat,
arg_func=cat_args,
dtypes=INT_DTYPES,
batch=POINTWISE_BATCH,
sizes=SIZES,
kwargs_func=cat_kwargs,
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -150,6 +150,7 @@ def enable(lib=aten_lib):
lib.impl("masked_select", masked_select, "CUDA")
lib.impl("stack", stack, "CUDA")
lib.impl("hstack", hstack, "CUDA")
lib.impl("cat", cat, "CUDA")


class use_gems:
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -14,6 +14,7 @@
from .bitwise_not import bitwise_not
from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor
from .bmm import bmm
from .cat import cat
from .clamp import clamp, clamp_tensor
from .cos import cos
from .cross_entropy_loss import cross_entropy_loss
@@ -232,4 +233,5 @@
"masked_select",
"stack",
"hstack",
"cat",
]
61 changes: 61 additions & 0 deletions src/flag_gems/ops/cat.py
@@ -0,0 +1,61 @@
import itertools
import logging
from typing import List, Tuple, Union

import torch
import triton

from ..utils import pointwise_dynamic
from ..utils.tensor_wrapper import StridedBuffer


# Identity pointwise kernel; instantiated per input rank and used to copy each
# input tensor into its destination slice of the output.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x


def cat(
A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
logging.debug("GEMS CAT")

if len(A) == 0:
raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
if len(A) == 1:
return A[0]
# Same rank check
inp_shapes = [list(_.shape) for _ in A]
inp0_shape = inp_shapes[0]
for s in inp_shapes[1:]:
if len(s) != len(inp0_shape):
raise RuntimeError(
f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
)
    # Normalize a possibly negative dim so the per-dimension size check below
    # skips the correct axis.
    if dim < 0:
        dim += len(inp0_shape)
    # Same size check
for tensor_idx, inp_shape in enumerate(inp_shapes):
for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
if idx == dim:
continue
elif length != common_length:
raise RuntimeError(
f"Sizes of tensors must match except in dimension {dim}. "
f"Expected size {common_length} but got size {length} for tensor number "
f"{tensor_idx} in the list"
)

    # The output shape matches the inputs except along `dim`, where the lengths add up.
    out_shape = list(inp0_shape)
    out_shape[dim] = sum(s[dim] for s in inp_shapes)
    out0 = torch.empty(out_shape, dtype=A[0].dtype, device=A[0].device)
    out0_strides = out0.stride()
    # Running element offsets of each input's slice within the output buffer:
    # tensor i starts right after the slices of tensors 0..i-1 along `dim`.
    out0_offsets = list(
        itertools.accumulate(
            [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
        )
    )

    # Copy each input into its slice of the output, viewed as a strided buffer
    # that reuses the output's strides plus the precomputed offset.
    for a, out0_offset in zip(A, out0_offsets):
        in_view = StridedBuffer(a, a.shape, a.stride())
        out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
        copy_func.instantiate(a.ndim)(in_view, out0=out_view)
    return out0
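
For context, a minimal usage sketch of the new op (not part of this commit; it assumes a CUDA device and an installed flag_gems build, and mirrors how the accuracy test below drives the kernel):

import torch
import flag_gems

a = torch.randn(4, 8, device="cuda")
b = torch.randn(2, 8, device="cuda")
# use_gems() routes aten::cat to the Triton-backed cat registered in __init__.py above
with flag_gems.use_gems():
    out = torch.cat([a, b], dim=0)
assert out.shape == (6, 8)
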
48 changes: 48 additions & 0 deletions tests/test_special_ops.py
@@ -590,3 +590,51 @@ def test_exception_hstack(shape, dtype):
with pytest.raises(RuntimeError):
with flag_gems.use_gems():
_ = torch.hstack(inp)


CAT_SHAPES = [
[(1, 32), (8, 32)],
[(16, 128), (32, 128)],
[(1024, 1024), (1024, 1024)],
[(1, 1024, 256), (8, 1024, 256), (16, 1024, 256)],
[(16, 320, 15), (32, 320, 15), (64, 320, 15)],
[(16, 128, 64, 64), (16, 128, 64, 64), (24, 128, 64, 64), (32, 128, 64, 64)],
]


def gen_cat_shapes_dim(shapes):
results = []
for tensor_shapes in shapes:
        assert all(
            len(s) == len(tensor_shapes[0]) for s in tensor_shapes
        ), "All tensor ranks must agree."
        assert all(
            s[1:] == tensor_shapes[0][1:] for s in tensor_shapes
        ), "All tensors must have the same shape except along the cat dim."
rank = len(tensor_shapes[0])
results.append([tensor_shapes, 0])
for dim in range(1, rank):
results.append(
[[(s[dim], *s[1:dim], s[0], *s[dim + 1 :]) for s in tensor_shapes], dim]
)
return results


@pytest.mark.parametrize("shape, dim", gen_cat_shapes_dim(CAT_SHAPES))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
def test_accuracy_cat(shape, dim, dtype):
if dtype in FLOAT_DTYPES:
inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape]
else:
inp = [
torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to(
dtype
)
for s in shape
]
ref_inp = [to_reference(_, True) for _ in inp]
ref_out = torch.cat(ref_inp, dim)

with flag_gems.use_gems():
res_out = torch.cat(inp, dim)
gems_assert_equal(res_out, ref_out)
