[Operator] Add vstack op
yjl0101 authored and yjl0101 committed Aug 28, 2024
1 parent 2db4271 commit 442760d
Showing 6 changed files with 197 additions and 0 deletions.
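
For context, a minimal usage sketch of the new op (an illustration only, assuming a CUDA device; it mirrors the accuracy test added below): inside flag_gems.use_gems(), torch.vstack dispatches to the Triton-backed implementation registered by this commit.

import torch
import flag_gems

# Sketch only: assumes a CUDA device is available.
a = torch.randn(3, 33, device="cuda")
b = torch.randn(7, 33, device="cuda")
with flag_gems.use_gems():  # routes torch.vstack to the new Triton op
    out = torch.vstack([a, b])
assert out.shape == (10, 33)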
18 changes: 18 additions & 0 deletions benchmark/test_special_perf.py
@@ -121,3 +121,21 @@ def padding_kwargs(dtype, batch, size):
kwargs_func=padding_kwargs,
)
bench.run()


def test_perf_vstack():
def vstack_args(dtype, batch, size):
inp1 = torch.randn(size=(batch, size), dtype=dtype, device="cuda")
inp2 = torch.randn(size=(batch + 1, size), dtype=dtype, device="cuda")
inp3 = torch.randn(size=(batch + 2, size), dtype=dtype, device="cuda")
return [[inp1, inp2, inp3]]

bench = Benchmark(
op_name="vstack",
torch_op=torch.vstack,
arg_func=vstack_args,
dtypes=FLOAT_DTYPES,
batch=(512),
sizes=SIZES,
)
bench.run()
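
As a reference point for the inputs above, torch.vstack concatenates along dim 0, so three inputs with batch, batch + 1, and batch + 2 rows stack into a (3 * batch + 3, size) output. A quick sketch with toy sizes (assumed for illustration; the benchmark itself uses batch=512 and SIZES):

import torch

batch, size = 4, 8  # toy sizes for illustration
parts = [torch.randn(batch + i, size) for i in range(3)]
assert torch.vstack(parts).shape == (3 * batch + 3, size)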
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -131,6 +131,7 @@ def enable(lib=aten_lib):
lib.impl("index_select", index_select, "CUDA")
lib.impl("masked_fill", masked_fill, "CUDA")
lib.impl("_unique2", _unique2, "CUDA")
lib.impl("vstack", vstack, "CUDA")


class use_gems:
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -87,6 +87,7 @@
from .unique import _unique2
from .var_mean import var_mean
from .vector_norm import vector_norm
from .vstack import vstack
from .where import where_scalar_other, where_scalar_self, where_self
from .zeros import zeros
from .zeros_like import zeros_like
@@ -205,4 +206,5 @@
"where_scalar_other",
"masked_fill",
"_unique2",
"vstack",
]
143 changes: 143 additions & 0 deletions src/flag_gems/ops/vstack.py
@@ -0,0 +1,143 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": k}, num_warps=w)
for w in [4, 8, 16, 32]
for k in [512, 1024, 2048, 4096]
],
key=[
"max_tile_elems",
],
)
@triton.jit
def vstack_kernel(
itensor_ptr0,
itensor_ptr1,
itensor_ptr2,
itensor_ptr3,
output_ptr,
local_row0,
local_row1,
local_row2,
local_row3,
exc_row_offset0,
exc_row_offset1,
exc_row_offset2,
exc_row_offset3,
total_row_offset,
row_stride,
max_tile_elems,
BLOCK_SIZE: tl.constexpr,
):
pid_x = tl.program_id(axis=0)
tensor_idx = tl.program_id(axis=1)
col_idx = tl.arange(0, BLOCK_SIZE)

    # Select this program's input pointer, exclusive row offset, and row count
    # based on which of the (up to) four tensors it handles (program axis 1).
    intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1)
intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr)
intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr)
base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1)
base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx)
base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx)
local_row = tl.where(tensor_idx == 0, local_row0, local_row1)
local_row = tl.where(tensor_idx == 2, local_row2, local_row)
local_row = tl.where(tensor_idx == 3, local_row3, local_row)

grid_stride_x = tl.num_programs(axis=0) * BLOCK_SIZE
start_idx = pid_x * BLOCK_SIZE
    end_idx = (local_row * row_stride).to(tl.int64)
    # Grid-stride loop over the flattened elements of this input tensor.
    for r in range(start_idx, end_idx, grid_stride_x):
idx = r + col_idx
offset_mask = idx < end_idx
in_offset = intensor_ptr + idx
        # Rows of this tensor begin at (total_row_offset + base_exc_row_idx) in the output.
        row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride
out_offset = output_ptr + row_stride_offset + idx
out = tl.load(in_offset, mask=offset_mask)
tl.store(out_offset, out, mask=offset_mask)


def vstack(tensors: list[torch.Tensor]):
logging.debug("GEMS VSTACK")

num_tensors = len(tensors)
assert num_tensors > 0

# Ensure all tensors are on the same device and have the same dtype
device = tensors[0].device
dtype = tensors[0].dtype
for tensor in tensors:
assert (
tensor.device == device
and tensor.dtype == dtype
and tensors[0].shape[1:] == tensor.shape[1:]
)

c_tensors = [t.contiguous() for t in tensors]
# Calculate the output shape
total_rows = sum(tensor.shape[0] for tensor in c_tensors)
output_shape = list(c_tensors[0].shape)
output_shape[0] = total_rows
output = torch.empty(output_shape, device=device, dtype=dtype)
row_stride = c_tensors[0].stride(0)

    # Process the inputs in groups of up to four tensors per kernel launch;
    # unused pointer slots are padded with empty tensors.
    outer_iters = triton.cdiv(num_tensors, 4)
    total_row_offset = 0
    for i in range(outer_iters):
max_rows = 1
itensors = []
exclusive_row = []
local_row = []
array_row_offset = 0
scheduled_num_tensors = 0
for j in range(4):
tensor_idx = i * 4 + j
if tensor_idx < num_tensors:
scheduled_num_tensors += 1
itensors.append(c_tensors[tensor_idx])
local_row.append(c_tensors[tensor_idx].shape[0])
exclusive_row.append(array_row_offset)
array_row_offset += c_tensors[tensor_idx].shape[0]
max_rows = max(max_rows, c_tensors[tensor_idx].shape[0])
else:
empty_tensor = torch.empty(
0, dtype=c_tensors[0].dtype, device=c_tensors[0].device
)
itensors.append(empty_tensor)
local_row.append(local_row[-1])
exclusive_row.append(exclusive_row[-1])
max_tile_elems = max_rows * row_stride
grid = lambda META: (
triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
scheduled_num_tensors,
)
# Launch the kernel
with torch.cuda.device(c_tensors[0].device):
vstack_kernel[grid](
itensors[0],
itensors[1],
itensors[2],
itensors[3],
output,
local_row[0],
local_row[1],
local_row[2],
local_row[3],
exclusive_row[0],
exclusive_row[1],
exclusive_row[2],
exclusive_row[3],
total_row_offset,
row_stride,
max_tile_elems,
)
total_row_offset += array_row_offset
return output
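
A plain-Python sketch of the offset bookkeeping above (hypothetical row counts, illustration only): the wrapper handles tensors in groups of up to four per kernel launch, and each tensor's rows start at total_row_offset + exclusive_row[j] in the output, which amounts to the exclusive prefix sums of the per-tensor row counts.

rows = [3, 7, 5, 2, 4, 6]  # hypothetical per-tensor row counts
placements = []
total_row_offset = 0
for i in range(0, len(rows), 4):  # up to four tensors per launch
    exclusive = 0
    for r in rows[i:i + 4]:
        placements.append(total_row_offset + exclusive)  # first output row of this tensor
        exclusive += r
    total_row_offset += exclusive
assert placements == [0, 3, 10, 15, 17, 21]  # exclusive prefix sums of rows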
13 changes: 13 additions & 0 deletions tests/accuracy_utils.py
@@ -18,6 +18,19 @@
REDUCTION_SHAPES = [(4096, 256 * i) for i in range(1, 10, 2)]
MNK_SHAPES = [15, 160, 1024]

VSTACK_SHAPES = [
[(3, 33), (7, 33)],
[(13, 3, 333), (17, 3, 333), (7, 3, 333)],
[(13, 3, 128, 5), (16, 3, 128, 5), (7, 3, 128, 5), (4, 3, 128, 5)],
[
(13, 3, 64, 5, 2),
(16, 3, 64, 5, 2),
(7, 3, 64, 5, 2),
(4, 3, 64, 5, 2),
(1, 3, 64, 5, 2),
],
]

DIM_POINTWISE_SHAPES = [
(1024, 1024, 1),
(16, 1024, 256),
20 changes: 20 additions & 0 deletions tests/test_special_ops.py
@@ -11,6 +11,7 @@
INT_DTYPES,
POINTWISE_SHAPES,
RESOLUTION,
VSTACK_SHAPES,
gems_assert_close,
gems_assert_equal,
to_reference,
@@ -343,6 +344,25 @@ def test_accuracy_unique(shape, dtype, sorted, return_inverse, return_counts):
return_counts=return_counts,
)
assert res_out.numel() == ref_out.numel()


@pytest.mark.parametrize("shape", VSTACK_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
def test_accuracy_vstack(shape, dtype):
if dtype in FLOAT_DTYPES:
inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape]
else:
inp = [
torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to(
dtype
)
for s in shape
]
ref_inp = [to_reference(_, True) for _ in inp]
ref_out = torch.vstack(ref_inp)

with flag_gems.use_gems():
        res_out = torch.vstack(inp)
gems_assert_equal(res_out, ref_out)


