[TIR] Shuffle in PointerValueTypeRewrite for scalar reads #15517
Conversation
Added an option `rewrite_scalar_read_to_vector_shuffle` in `PointerValueTypeRewrite` (currently only enabled for Vulkan). When enabled, if a buffer has both scalar and vector reads, the buffer will be vectorized when possible and the scalar reads will be performed via T.Shuffle. Close apache#15463.
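The index arithmetic behind the rewrite can be sketched in plain Python (an illustrative helper, not TVM's actual implementation): once a buffer is reinterpreted as `lanes`-wide vector elements, a scalar read at a flat index becomes a vector load plus a lane selection, which is what T.Shuffle expresses.

```python
# Illustrative sketch (hypothetical helper, not TVM's implementation):
# when a buffer is vectorized to `lanes`-wide elements, a scalar read
# buf[index] becomes roughly T.Shuffle([buf_vec[vec_index]], [lane]).
def scalar_read_to_shuffle(index: int, lanes: int):
    """Map a flat scalar index to (vector index, lane within the vector)."""
    return index // lanes, index % lanes

# e.g. with float16x4 vectorization, buf[10] reads lane 2 of buf_vec[2]
```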
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot
@Lunderberg do you mind helping review this?

I can, though I probably won't have time available to do so until later next week.
Ah ok, I can take a look then.
tqchen left a comment:

Thanks wuwei, some comments.
if (me->coeff > 0) {
  // When coeff == 0, the index is constant and doesn't need to be recorded since it can
  // always be rewritten to shuffle.
  var_info.access_dtype.insert(access_dtype.with_lanes(me->coeff));
I think we need a data structure that captures two things:

struct VarReadInfo {
  DataType access_dtype;
  // maintained as the GCD of all coefficients of the access index
  int64_t simd_coeff;
};

This is mainly to ensure that we don't end up using a very large vector here.
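The GCD bookkeeping suggested above can be sketched in Python (the names mirror the proposed struct but are illustrative; `math.gcd` is the stdlib helper):

```python
from math import gcd

# Sketch of the suggested VarReadInfo bookkeeping (illustrative names):
# track the GCD of all access-index coefficients so the vector width
# chosen for the buffer never exceeds what every access can support.
class VarReadInfo:
    def __init__(self):
        # gcd(0, x) == x, so 0 acts as the identity before any access
        self.simd_coeff = 0

    def record_access(self, coeff: int) -> None:
        self.simd_coeff = gcd(self.simd_coeff, coeff)

info = VarReadInfo()
for c in (8, 12, 4):  # coefficients seen across the buffer's accesses
    info.record_access(c)
# info.simd_coeff is now 4: at most 4 lanes are safe for every access
```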
access_dtype also contains the dtype for writes, so the result dtype will eventually be bounded by the vectorization lanes of the writes.
the main corner case would be when we do not have writes and only reads.
updated
Force-pushed from e523259 to 9a90ec1.
Hi @vinx13, we noticed that this PR breaks the WebGPU codegen, as the WebGPU codegen right now does not support tir::ShuffleNode; therefore, exceptions are thrown in that pass. Here is a reproducible example: https://gist.github.com/MasterJH5574/3e6141a1d6dcafd383ed6f7ac27e3318. It needs to be run under the commit mlc-ai/relax@631f37b since the WebGPU codegen in mlc-ai/relax is slightly different from the codegen here due to some conflicts. cc @tqchen

PrimFunc after PointerValueTypeRewrite:

# from tvm.script import tir as T
@T.prim_func
def fused_fused_decode4_NT_matmul9_kernel(lv1654: T.handle("float16x4", "global"), lv635: T.handle("uint32", "global"), lv636: T.handle("float16", "global"), var_NT_matmul_intermediate: T.handle("float16", "global")):
T.func_attr({"calling_conv": 2, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "wasm32-unknown-unknown-wasm", "tag": ""}, "keys": ["webgpu", "gpu"], "kind": "webgpu", "max_num_threads": 256, "tag": ""}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x", "threadIdx.y"], "tir.noalias": T.bool(True)})
var_NT_matmul_intermediate_1 = T.decl_buffer((22016,), "float16", data=var_NT_matmul_intermediate)
red_buf0 = T.handle("float16", "shared")
red_buf0_1 = T.decl_buffer((256,), "float16", data=red_buf0, scope="shared")
var_NT_matmul_intermediate_rf_local_1 = T.handle("float16", "local")
var_NT_matmul_intermediate_rf_local_1_1 = T.decl_buffer((1,), "float16", data=var_NT_matmul_intermediate_rf_local_1, scope="local")
lv636_1 = T.decl_buffer((2818048,), "float16", data=lv636)
lv635_1 = T.decl_buffer((11272192,), "uint32", data=lv635)
lv635_local = T.handle("uint32", "local")
lv635_local_1 = T.decl_buffer((1,), "uint32", data=lv635_local, scope="local")
var_NT_matmul_intermediate_rf_local = T.handle("float16x2", "local")
var_NT_matmul_intermediate_rf_local_2 = T.decl_buffer((2,), "float16", data=var_NT_matmul_intermediate_rf_local, scope="local")
lv1654_1 = T.decl_buffer((4096,), "float16", data=lv1654)
lv1654_shared = T.handle("float16", "shared")
lv1654_shared_1 = T.decl_buffer((4096,), "float16", data=lv1654_shared, scope="shared")
blockIdx_x = T.launch_thread("blockIdx.x", 688)
lv1654_shared = T.allocate([4096], "float16", "shared")
var_NT_matmul_intermediate_rf_local = T.allocate([1], "float16x2", "local")
lv635_local = T.allocate([1], "uint32", "local")
var_NT_matmul_intermediate_rf_local_1 = T.allocate([1], "float16", "local")
red_buf0 = T.allocate([256], "float16", "shared")
T.attr(red_buf0, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 8)
threadIdx_y = T.launch_thread("threadIdx.y", 32)
ax2_0 = T.int32()
with T.attr(ax2_0, "pragma_vectorize", 1):
for ax2_0 in range(4):
lv1654_2 = T.Buffer((1024,), "float16x4", data=lv1654)
lv1654_shared_1[ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4:ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4 + 4] = lv1654_2[T.Div(ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4, 4)]
var_NT_matmul_intermediate_rf_local_3 = T.Buffer((1,), "float16x2", data=var_NT_matmul_intermediate_rf_local, scope="local")
var_NT_matmul_intermediate_rf_local_3[0] = T.Broadcast(T.float16(0), 2)
T.tvm_storage_sync("shared")
for ax1_0_fused_ax1_1_fused_0 in range(64):
lv635_local_1[0] = lv635_1[blockIdx_x * 16384 + threadIdx_y * 512 + ax1_0_fused_ax1_1_fused_0 * 8 + threadIdx_x]
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(0), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(8), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 4:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 4 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(16), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 6:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 6 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(24), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_1_1[0] = T.float16(0)
var_NT_matmul_intermediate_rf_local_1_1[0] = var_NT_matmul_intermediate_rf_local_1_1[0] + T.Shuffle([var_NT_matmul_intermediate_rf_local_3[0]], [0])
var_NT_matmul_intermediate_rf_local_1_1[0] = var_NT_matmul_intermediate_rf_local_1_1[0] + T.Shuffle([var_NT_matmul_intermediate_rf_local_3[0]], [1])
with T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float16(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
T.tvm_storage_sync("shared")
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = var_NT_matmul_intermediate_rf_local_1_1[0]
T.tvm_storage_sync("shared")
if threadIdx_x < 4:
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 4]
T.tvm_storage_sync("shared")
if threadIdx_x < 2:
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 2]
T.tvm_storage_sync("shared")
if threadIdx_x < 1:
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 1]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
var_NT_matmul_intermediate_1[blockIdx_x * 32 + threadIdx_y] = red_buf0_1[threadIdx_y * 8]

Codegen error message:
Is it possible to support it in codegen? Usually this can be supported via element extraction, e.g. extracting the requested lane from the vector value.
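A minimal sketch of the element-extraction idea, assuming a string-emitting codegen (the function name and the WGSL-style `v[i]` syntax are illustrative assumptions, not the actual WebGPU codegen):

```python
# Hypothetical sketch: lower a single-index shuffle to element extraction
# instead of a vector shuffle intrinsic. WGSL (and most shading languages)
# allow indexing a lane out of a vector value, e.g. `v[2]`.
def lower_shuffle(vector_expr: str, indices: list) -> str:
    if len(indices) == 1:
        # scalar extraction: emit plain element indexing
        return f"{vector_expr}[{indices[0]}]"
    raise NotImplementedError("general multi-lane shuffles need more work")

# e.g. T.Shuffle([v], [1]) would be emitted as "v[1]"
```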
I think we should support it via codegen.
…ache#15517)" This reverts commit 925148e.
Will look into adding Shuffle support for WebGPU!
Hi @vinx13, here's a minimal reproducible script; it should reproduce the same error running on the up-to-date unity branch: https://gist.github.com/guoyaol/af7d1161124987b69b8eb2744b1c399e
@guoyaol it may be caused by unaligned vectorization, or a false alarm when the arithmetic analyzer can't handle the index pattern.
cc @tqchen @Lunderberg @sunggg