[Operator] Add hstack op (#210)
* [Operator] Add hstack op
zfu82 authored Sep 13, 2024
1 parent 782e9e7 commit 0c7b989
Showing 5 changed files with 141 additions and 0 deletions.
16 changes: 16 additions & 0 deletions benchmark/test_special_perf.py
@@ -208,3 +208,19 @@ def stack_args(dtype, batch, size):
sizes=SIZES,
)
bench.run()


def test_perf_hstack():
def hstack_args(dtype, batch, size):
inp = torch.randn(size=(batch, size), dtype=dtype, device="cuda")
return {(inp,) * 3}

bench = Benchmark(
op_name="hstack",
torch_op=torch.hstack,
arg_func=hstack_args,
dtypes=FLOAT_DTYPES,
batch=(512),
sizes=SIZES,
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -149,6 +149,7 @@ def enable(lib=aten_lib):
lib.impl("repeat", repeat, "CUDA")
lib.impl("masked_select", masked_select, "CUDA")
lib.impl("stack", stack, "CUDA")
lib.impl("hstack", hstack, "CUDA")


class use_gems:
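For illustration only (not part of this diff): once registered here, enabling flag_gems routes torch.hstack calls on CUDA tensors to the new kernel, exactly as the tests below do:

import torch
import flag_gems

a = torch.randn(16, 256, device="cuda")
b = torch.randn(16, 128, device="cuda")
with flag_gems.use_gems():  # torch.hstack now dispatches to the GEMS kernel
    out = torch.hstack((a, b))  # out.shape == torch.Size([16, 384])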
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -34,6 +34,7 @@
from .gelu import gelu
from .groupnorm import group_norm
from .gt import gt, gt_scalar
from .hstack import hstack
from .index_select import index_select
from .isclose import allclose, isclose
from .isfinite import isfinite
@@ -231,4 +232,5 @@
"repeat",
"masked_select",
"stack",
"hstack",
]
71 changes: 71 additions & 0 deletions src/flag_gems/ops/hstack.py
@@ -0,0 +1,71 @@
import itertools
import logging
from typing import List, Tuple, Union

import torch
import triton

from ..utils import pointwise_dynamic
from ..utils.tensor_wrapper import StridedBuffer


@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
return x


def hstack(
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
) -> torch.Tensor:
logging.debug("GEMS HSTACK")

if len(tensors) == 0:
raise RuntimeError("hstack expected a non-empty TensorList")

if tensors[0].ndim == 0:
tensors[0] = tensors[0].view(1)
inp0_shape = tensors[0].shape
out_shape = list(inp0_shape)
inp_shapes = [inp0_shape]

if len(inp0_shape) == 1:
dim = 0
else:
dim = 1

for tensor_num, tensor in enumerate(tensors[1:]):
if tensor.ndim == 0:
tensor = tensor.view(1)
if tensor.ndim != tensors[0].ndim:
raise RuntimeError(
f"Tensors must have same number of dimensions: got {tensors[0].ndim} and {tensor.ndim}"
)

inp_shape = tensor.shape
inp_shapes.append(inp_shape)

for i in range(len(inp_shape)):
if i != dim and inp_shape[i] != inp0_shape[i]:
raise RuntimeError(
f"Sizes of tensors must match except in dimension {dim}. \
Expected size {inp0_shape[i]} but got size {inp_shape[i]} \
for tensor number {tensor_num + 1} in the list."
)

out_shape[dim] = sum(s[dim] for s in inp_shapes)

out0 = torch.empty(out_shape, dtype=tensors[0].dtype, device=tensors[0].device)
out0_strides = out0.stride()
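    # Running offsets (in elements of out0's storage) marking where each input's block starts along the concat dim.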
out0_offsets = list(
itertools.accumulate(
[s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0
)
)

for a, out0_offset in zip(tensors, out0_offsets):
in_view = StridedBuffer(a, a.shape, a.stride())
out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)
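        # Rank-specialized elementwise copy from the input into its strided slice of the output.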
copy_func.instantiate(a.ndim)(in_view, out0=out_view)

return out0
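
For context only (not part of this diff): the implementation above reproduces torch.hstack semantics by copying each input into a strided view of a preallocated output rather than materializing a torch.cat. A minimal reference sketch of those semantics in plain PyTorch:

import torch

def hstack_reference(tensors):
    # 0-D inputs are promoted to 1-D; 1-D inputs concatenate along dim 0, higher-rank inputs along dim 1.
    tensors = [t.view(1) if t.ndim == 0 else t for t in tensors]
    dim = 0 if tensors[0].ndim == 1 else 1
    return torch.cat(tensors, dim=dim)

# e.g. hstack_reference([torch.randn(16, 256), torch.randn(16, 128)]).shape == torch.Size([16, 384])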
51 changes: 51 additions & 0 deletions tests/test_special_ops.py
@@ -532,3 +532,54 @@ def test_accuracy_stack(shape, dim, dtype):
with flag_gems.use_gems():
res_out = torch.stack(inp, dim)
gems_assert_equal(res_out, ref_out)


HSTACK_SHAPES = [
[(8,), (16,)],
[(16, 256), (16, 128)],
[(20, 320, 15), (20, 160, 15), (20, 80, 15)],
]


@pytest.mark.parametrize("shape", HSTACK_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
def test_accuracy_hstack(shape, dtype):
if dtype in FLOAT_DTYPES:
inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape]
else:
inp = [
torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to(
dtype
)
for s in shape
]
ref_inp = [to_reference(_) for _ in inp]
ref_out = torch.hstack(ref_inp)

with flag_gems.use_gems():
res_out = torch.hstack(inp)
gems_assert_equal(res_out, ref_out)


HSTACK_EXCEPTION_SHAPES = [
[(16, 256), (16,)],
[(16, 256), (8, 128)],
]


@pytest.mark.parametrize("shape", HSTACK_EXCEPTION_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
def test_exception_hstack(shape, dtype):
if dtype in FLOAT_DTYPES:
inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape]
else:
inp = [
torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to(
dtype
)
for s in shape
]

with pytest.raises(RuntimeError):
with flag_gems.use_gems():
_ = torch.hstack(inp)
