Skip to content

Commit

Permalink
[TVMScript][Parser] Add more warp-level builtins and Range (apache#…
Browse files Browse the repository at this point in the history
…14279)

# Motivation
Several builtins "tvm_storage_sync", "tvm_warp_shuffle", "tvm_warp_shuffle_up", "tvm_warp_shuffle_down", "tvm_warp_activemask" and `Range` will appear in TVMScript printer but are missing in TVMScript parser. This PR fix the behavior.
  • Loading branch information
yzh119 authored Mar 12, 2023
1 parent e3c8f2b commit caf6b03
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 13 deletions.
51 changes: 39 additions & 12 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
import numpy as np # type: ignore

from tvm import tir
from tvm.ir import Range, Type
from tvm import ir
from tvm.ir import Type
from tvm.ir.base import deprecated
from tvm.runtime import String, convert, ndarray
from tvm.target import Target
Expand Down Expand Up @@ -496,7 +497,7 @@ def alloc_buffer(
)


def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range:
"""The range constructor.
Parameters
Expand All @@ -509,21 +510,21 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
res : Range
The Range.
"""
if isinstance(dom, Range):
if isinstance(dom, ir.Range):
return dom
if isinstance(dom, (list, tuple)):
return Range(dom[0], dom[1])
return ir.Range(dom[0], dom[1])
if hasattr(dom, "dtype"):
return Range(IntImm(dom.dtype, 0), dom)
return Range(0, dom)
return ir.Range(IntImm(dom.dtype, 0), dom)
return ir.Range(0, dom)


class axis: # pylint: disable=invalid-name
"""The axis class"""

@staticmethod
def spatial(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
Expand Down Expand Up @@ -551,7 +552,7 @@ def spatial(

@staticmethod
def reduce(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
Expand Down Expand Up @@ -579,7 +580,7 @@ def reduce(

@staticmethod
def scan(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
Expand Down Expand Up @@ -607,7 +608,7 @@ def scan(

@staticmethod
def opaque(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
Expand Down Expand Up @@ -1288,7 +1289,7 @@ def buffer_store(

def prefetch(
buffer: Buffer, # pylint: disable=redefined-outer-name
bounds: List[Range],
bounds: List[ir.Range],
) -> None:
"""The prefetch hint for a buffer.
Expand Down Expand Up @@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-buil
return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member


def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar:
def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar:
"""The iteration variable.
Parameters
Expand Down Expand Up @@ -1666,6 +1667,21 @@ def target(target_config: Union[Dict, str]) -> Target:
return Target(target_config)


def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name
"""
Create a Range object.
Parameters
----------
begin : PrimExpr
The begin value of the range.
end : Optional[PrimExpr]
The end value of the range.
"""
return ir.Range(begin, end)


class meta_var: # pylint: disable=invalid-name
"""A meta variable used in TVMScript metaprogramming. It means that the value of the variable
does not appear in the final TIR, but only stays in the parser.
Expand Down Expand Up @@ -1782,6 +1798,11 @@ def wrapped(*args, **kwargs):
tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
tvm_storage_sync = _tir_op.tvm_storage_sync
tvm_warp_shuffle = _tir_op.tvm_warp_shuffle
tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up
tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
tvm_warp_activemask = _tir_op.tvm_warp_activemask
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
assume = _op_wrapper(_tir_op.assume)
Expand Down Expand Up @@ -2042,6 +2063,11 @@ def wrapped(*args, **kwargs):
"tvm_bmma_sync",
"tvm_fill_fragment",
"tvm_store_matrix_sync",
"tvm_storage_sync",
"tvm_warp_shuffle",
"tvm_warp_shuffle_up",
"tvm_warp_shuffle_down",
"tvm_warp_activemask",
"ptx_mma",
"ptx_mma_sp",
"ptx_ldmatrix",
Expand Down Expand Up @@ -2109,4 +2135,5 @@ def wrapped(*args, **kwargs):
"Let",
"IterVar",
"CommReducer",
"Range",
]
108 changes: 107 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,8 @@ def lookup_param(param_name, span=None):


def tvm_thread_allreduce(*freduce_args):
"""
"""Perform allreduce inside threadblock.
Parameters
----------
freduce_args : Expr
Expand All @@ -583,6 +584,111 @@ def tvm_thread_allreduce(*freduce_args):
return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)


def tvm_storage_sync(storage_scope):
"""Perform synchronization in specified scope.
Parameters
----------
storage_scope : str
The storage scope to perform synchronization.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.tvm_storage_sync", storage_scope)


def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
"""Exchange value between threads inside a warp.
Parameters
----------
mask : PrimExpr
The warp mask indicates active threads inside warp.
value : PrimExpr
The value to exchange.
warp_id : PrimExpr
The source lane index to fetch value.
width : PrimExpr
The width of sub-sections to perform warp shuffle.
warp_size : PrimExpr
The warp size.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size)


def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
"""Copy value from a lane with lower (by offset) index relative to caller.
Parameters
----------
mask : PrimExpr
The warp mask indicates active threads inside warp.
value : PrimExpr
The value to exchange.
offset : PrimExpr
The difference between source lane index and destination lane index:
`offset = dst_lane_idx - src_lane_idx`
width : PrimExpr
The width of sub-sections to perform warp shuffle.
warp_size : PrimExpr
The warp size.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size
)


def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
"""Copy value from a lane with higher (by offset) index relative to caller.
Parameters
----------
mask : PrimExpr
The warp mask indicates active threads inside warp.
value : PrimExpr
The value to exchange.
offset : PrimExpr
The difference between source lane index and destination lane index:
`offset = src_lane_idx - dst_lane_idx`
width : PrimExpr
The width of sub-sections to perform warp shuffle.
warp_size : PrimExpr
The warp size.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size
)


def tvm_warp_activemask():
"""Return a 32-bit mask indicates currently active threads in a calling warp.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("uint32", "tir.tvm_warp_activemask")


def type_annotation(dtype):
"""Create a type annotation expression
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3623,6 +3623,60 @@ def main(A: T.handle, B: T.handle):
return main


def tvm_shfl_builtins():
@T.prim_func
def func(
A: T.handle("float32"),
B: T.handle("float32"),
C: T.handle("float32"),
):
blockIdx_x = T.launch_thread("blockIdx.x", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 32)
A_warp = T.allocate([1], "float32", "local")
B_warp = T.allocate([1], "float32", "local")
red_buf0 = T.allocate([1], "float32", "local")
A_warp_1 = T.Buffer((32,), data=A_warp, scope="local")
A_1 = T.Buffer((32,), data=A)
A_warp_1[0] = A_1[threadIdx_x]
B_warp_1 = T.Buffer((32,), data=B_warp, scope="local")
T.tvm_storage_sync("warp")
B_warp_1[0] = T.tvm_warp_shuffle(
T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32
) + T.float32(1)
red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
mask = T.allocate([1], "uint32", "local")
t0 = T.allocate([1], "float32", "local")
red_buf0_1[0] = A_warp_1[0]
mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_1[0] = T.tvm_warp_activemask()
t0_1 = T.Buffer((1,), data=t0, scope="local")
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32)
# NOTE(Zihao): test tvm_warp_shuffle_up
red_buf0_1[0] = T.tvm_warp_shuffle_up(mask_1[0], red_buf0_1[0], 0, 32, 32)
if threadIdx_x == 0:
C_1 = T.Buffer((1,), data=C)
C_1[0] = red_buf0_1[0]
B_1 = T.Buffer((32,), data=B)
B_1[threadIdx_x] = B_warp_1[0]

return func


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -3686,6 +3740,7 @@ def main(A: T.handle, B: T.handle):
let_stmt_value,
string_stride,
merge_shape_var_def,
tvm_shfl_builtins,
)


Expand Down

0 comments on commit caf6b03

Please sign in to comment.