
Conversation

@vinx13 (Member) commented on Aug 9, 2023

Added an option `rewrite_scalar_read_to_vector_shuffle` to `PointerValueTypeRewrite` (currently
only enabled for Vulkan). When enabled, if a buffer has both scalar and vector reads, the buffer
is vectorized where possible and the scalar reads are expressed via `T.Shuffle`.

Closes #15463.

cc @tqchen @Lunderberg @sunggg

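To make the rewrite concrete, here is a minimal TVMScript sketch of the before/after shape of the transformation. It is illustrative only: the buffer names, shapes, and loops are hypothetical, and the actual pass operates on the flattened device-side buffers rather than hand-written functions like these.

# Hypothetical illustration of the rewrite described above.
from tvm.script import tir as T

# Before: buffer A is read both as 4-lane vectors and as scalars.
@T.prim_func
def before(A: T.Buffer((64,), "float16"), B: T.Buffer((64,), "float16")):
    for i in range(16):
        B[i * 4:i * 4 + 4] = A[i * 4:i * 4 + 4]  # vector read
    for j in range(64):
        B[j] = A[j]                              # scalar read

# After: A is rewritten to a float16x4 buffer; the vector read becomes a direct
# element load, and the scalar read becomes an element extraction via T.Shuffle.
@T.prim_func
def after(A: T.Buffer((16,), "float16x4"), B: T.Buffer((64,), "float16")):
    for i in range(16):
        B[i * 4:i * 4 + 4] = A[i]
    for j in range(64):
        B[j] = T.Shuffle([A[j // 4]], [j % 4])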
@tvm-bot (Collaborator) commented on Aug 9, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@github-actions bot requested review from Lunderberg and tqchen on August 9, 2023 22:16
@tqchen (Member) commented on Aug 9, 2023

@Lunderberg do you mind helping to review this?

@Lunderberg (Contributor) commented

I can, though I probably won't have time available to do so until later next week.

@tqchen (Member) commented on Aug 10, 2023

Ah ok, I can help take a look then.

@tqchen (Member) left a review comment

Thanks Wuwei, some comments:

if (me->coeff > 0) {
  // When coeff == 0, the index is constant and doesn't need to be recorded since it can
  // always be rewritten to shuffle.
  var_info.access_dtype.insert(access_dtype.with_lanes(me->coeff));
@tqchen (Member):

I think we need a data structure that captures two things:

struct VarReadInfo {
   DataType access_dtype;
   // maintained as the GCD of all coefficients of the access indices
   int64_t simd_coeff;
};

This is mainly to ensure that we don't end up using a very large vector here.
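For illustration, here is a small Python sketch of the idea (not the actual C++ pass code; the function and its inputs are hypothetical): keep a running GCD of the access-index coefficients per buffer variable, and only pick a vector width that divides this GCD.

# Hypothetical sketch: bound the rewritten vector width by the GCD of the
# linear coefficients of all access indices on a buffer.
from math import gcd

def choose_lanes(access_lanes, index_coeffs, max_lanes=4):
    # access_lanes: lanes of each read/write on the buffer (1 for scalar reads
    # that could be turned into shuffles).
    # index_coeffs: linear coefficient of each access index (0 for constants,
    # which can always be rewritten to a shuffle and so never constrain us).
    simd_coeff = 0
    for c in index_coeffs:
        simd_coeff = gcd(simd_coeff, c)  # gcd(0, c) == c, so constants drop out
    lanes = max(access_lanes)
    # Shrink the candidate width until it divides simd_coeff (and fits the cap).
    while lanes > 1 and (lanes > max_lanes or simd_coeff % lanes != 0):
        lanes //= 2
    return lanes

print(choose_lanes([4, 1], [4, 8]))  # 4: every index is a multiple of 4
print(choose_lanes([4, 1], [4, 6]))  # 2: a stride-6 access forbids 4 lanes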

@vinx13 (Member, Author):

access_dtype also contains the dtype of writes, so the result dtype will eventually be bounded by the vectorization lanes of the writes.

@tqchen (Member):

The main corner case would be when there are only reads and no writes.

@vinx13 (Member, Author):

updated

@vinx13 force-pushed the feat/pointer-value-type-rewrite-shuffle branch from e523259 to 9a90ec1 on August 14, 2023 18:57
@tqchen merged commit 925148e into apache:main on Aug 18, 2023
@MasterJH5574 (Contributor) commented on Sep 17, 2023

Hi @vinx13, we noticed that this PR breaks the WebGPU codegen, which currently does not support tir::ShuffleNode. As a result, exceptions are thrown for the output of the PointerValueTypeRewrite pass, which runs at the beginning of WebGPU codegen.

Here is a reproducible example: https://gist.github.com/MasterJH5574/3e6141a1d6dcafd383ed6f7ac27e3318. It needs to be run under commit mlc-ai/relax@631f37b, since the WebGPU codegen in mlc-ai/relax is slightly different from the codegen here due to some conflicts.

cc @tqchen


PrimFunc after PointerValueTypeRewrite:

# from tvm.script import tir as T

@T.prim_func
def fused_fused_decode4_NT_matmul9_kernel(lv1654: T.handle("float16x4", "global"), lv635: T.handle("uint32", "global"), lv636: T.handle("float16", "global"), var_NT_matmul_intermediate: T.handle("float16", "global")):
    T.func_attr({"calling_conv": 2, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "wasm32-unknown-unknown-wasm", "tag": ""}, "keys": ["webgpu", "gpu"], "kind": "webgpu", "max_num_threads": 256, "tag": ""}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x", "threadIdx.y"], "tir.noalias": T.bool(True)})
    var_NT_matmul_intermediate_1 = T.decl_buffer((22016,), "float16", data=var_NT_matmul_intermediate)
    red_buf0 = T.handle("float16", "shared")
    red_buf0_1 = T.decl_buffer((256,), "float16", data=red_buf0, scope="shared")
    var_NT_matmul_intermediate_rf_local_1 = T.handle("float16", "local")
    var_NT_matmul_intermediate_rf_local_1_1 = T.decl_buffer((1,), "float16", data=var_NT_matmul_intermediate_rf_local_1, scope="local")
    lv636_1 = T.decl_buffer((2818048,), "float16", data=lv636)
    lv635_1 = T.decl_buffer((11272192,), "uint32", data=lv635)
    lv635_local = T.handle("uint32", "local")
    lv635_local_1 = T.decl_buffer((1,), "uint32", data=lv635_local, scope="local")
    var_NT_matmul_intermediate_rf_local = T.handle("float16x2", "local")
    var_NT_matmul_intermediate_rf_local_2 = T.decl_buffer((2,), "float16", data=var_NT_matmul_intermediate_rf_local, scope="local")
    lv1654_1 = T.decl_buffer((4096,), "float16", data=lv1654)
    lv1654_shared = T.handle("float16", "shared")
    lv1654_shared_1 = T.decl_buffer((4096,), "float16", data=lv1654_shared, scope="shared")
    blockIdx_x = T.launch_thread("blockIdx.x", 688)
    lv1654_shared = T.allocate([4096], "float16", "shared")
    var_NT_matmul_intermediate_rf_local = T.allocate([1], "float16x2", "local")
    lv635_local = T.allocate([1], "uint32", "local")
    var_NT_matmul_intermediate_rf_local_1 = T.allocate([1], "float16", "local")
    red_buf0 = T.allocate([256], "float16", "shared")
    T.attr(red_buf0, "volatile_scope", 1)
    threadIdx_x = T.launch_thread("threadIdx.x", 8)
    threadIdx_y = T.launch_thread("threadIdx.y", 32)
    ax2_0 = T.int32()
    with T.attr(ax2_0, "pragma_vectorize", 1):
        for ax2_0 in range(4):
            lv1654_2 = T.Buffer((1024,), "float16x4", data=lv1654)
            lv1654_shared_1[ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4:ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4 + 4] = lv1654_2[T.Div(ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4, 4)]
    var_NT_matmul_intermediate_rf_local_3 = T.Buffer((1,), "float16x2", data=var_NT_matmul_intermediate_rf_local, scope="local")
    var_NT_matmul_intermediate_rf_local_3[0] = T.Broadcast(T.float16(0), 2)
    T.tvm_storage_sync("shared")
    for ax1_0_fused_ax1_1_fused_0 in range(64):
        lv635_local_1[0] = lv635_1[blockIdx_x * 16384 + threadIdx_y * 512 + ax1_0_fused_ax1_1_fused_0 * 8 + threadIdx_x]
        var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(0), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
        var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(8), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
        var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 4:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 4 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(16), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
        var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 6:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 6 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(24), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
    var_NT_matmul_intermediate_rf_local_1_1[0] = T.float16(0)
    var_NT_matmul_intermediate_rf_local_1_1[0] = var_NT_matmul_intermediate_rf_local_1_1[0] + T.Shuffle([var_NT_matmul_intermediate_rf_local_3[0]], [0])
    var_NT_matmul_intermediate_rf_local_1_1[0] = var_NT_matmul_intermediate_rf_local_1_1[0] + T.Shuffle([var_NT_matmul_intermediate_rf_local_3[0]], [1])
    with T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float16(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
        T.tvm_storage_sync("shared")
        red_buf0_1[threadIdx_y * 8 + threadIdx_x] = var_NT_matmul_intermediate_rf_local_1_1[0]
        T.tvm_storage_sync("shared")
        if threadIdx_x < 4:
            red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 4]
        T.tvm_storage_sync("shared")
        if threadIdx_x < 2:
            red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 2]
        T.tvm_storage_sync("shared")
        if threadIdx_x < 1:
            red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 1]
        T.tvm_storage_sync("shared")
    if threadIdx_x == 0:
        var_NT_matmul_intermediate_1[blockIdx_x * 32 + threadIdx_y] = red_buf0_1[threadIdx_y * 8]

Codegen error message:

Traceback (most recent call last):
  File "/ssd1/ruihangl/workspace/mlc-llm/workspace/webgpu.py", line 169, in <module>
    lib = tvm.build(
  File "/home/ruihangl/tvm/python/tvm/driver/build_module.py", line 281, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "/home/ruihangl/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/ruihangl/tvm/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/ruihangl/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)+0x31a) [0x7f7a8771fc7a]
  [bt] (7) /home/ruihangl/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::SeqStmtNode const*)+0x94) [0x7f7a87720504]
  [bt] (6) /home/ruihangl/tvm/build/libtvm.so(tvm::codegen::CodeGenWebGPU::VisitStmt_(tvm::tir::BufferStoreNode const*)+0x12d) [0x7f7a877580fd]
  [bt] (5) /home/ruihangl/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)+0x30) [0x7f7a87720de0]
  [bt] (4) /home/ruihangl/tvm/build/libtvm.so(void tvm::codegen::PrintBinaryExpr<tvm::tir::AddNode>(tvm::tir::AddNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)+0x15f) [0x7f7a8772158f]
  [bt] (3) /home/ruihangl/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::ShuffleNode const*, std::ostream&)+0x31) [0x7f7a8771ee21]
  [bt] (2) /home/ruihangl/tvm/build/libtvm.so(+0x1572726) [0x7f7a86344726]
  [bt] (1) /home/ruihangl/tvm/build/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0xff) [0x7f7a86344f6f]
  [bt] (0) /home/ruihangl/tvm/build/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x13) [0x7f7a883395c3]
  File "/home/ruihangl/tvm/src/target/source/codegen_c.cc", line 816
Shuffle: not supported 

@vinx13 (Member, Author) commented on Sep 17, 2023

Is it possible to support it in codegen? Usually this can be supported via element extraction, e.g. vec.x/y/z. Alternatively, we can set rewrite_scalar_access_to_vector_shuffle to false.
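As a rough illustration of that suggestion (not the actual CodeGenWebGPU implementation; the helper below is hypothetical), a single-vector T.Shuffle with a constant index could be printed as a component access on the vector expression:

# Hypothetical sketch of lowering T.Shuffle([vec], [k]) with constant k to a
# component access in the emitted source (".x/.y/.z/.w" for up to 4 lanes).
def lower_shuffle_to_component(vec_expr: str, index: int) -> str:
    components = "xyzw"
    if not 0 <= index < len(components):
        raise ValueError("only vectors of up to 4 lanes handled in this sketch")
    return f"{vec_expr}.{components[index]}"

print(lower_shuffle_to_component("var_NT_matmul_intermediate_rf_local[0]", 1))
# prints: var_NT_matmul_intermediate_rf_local[0].y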

@tqchen (Member) commented on Sep 17, 2023

I think we should support it via codegen.

MasterJH5574 added a commit to MasterJH5574/tvm that referenced this pull request Sep 17, 2023
@CharlieFRuan (Member) commented
Will look into adding Shuffle support for WebGPU!

MasterJH5574 added a commit to MasterJH5574/tvm that referenced this pull request Sep 19, 2023
MasterJH5574 added a commit to MasterJH5574/tvm that referenced this pull request Oct 2, 2023
jinhongyii pushed a commit to jinhongyii/tvm that referenced this pull request Oct 19, 2023
jinhongyii pushed a commit to jinhongyii/tvm that referenced this pull request Nov 3, 2023
@guoyaol (Contributor) commented on Nov 15, 2023

Hi @vinx13, I was building a TVM module for Stable Diffusion and found that some kernels fail the ICHECK added in this PR. Do you have any insights on how to tackle this? Thanks.

  File "/Users/guoyaoli/tvm_work/tvm/src/tir/transforms/storage_rewrite.cc", line 1466
InternalError: Check failed: (me->coeff == 0 || info.factor() % me->coeff == 0) is false: 

Here's a minimal reproducible script; it should produce the same error on the up-to-date unity branch: https://gist.github.com/guoyaol/af7d1161124987b69b8eb2744b1c399e

cc @MasterJH5574 @tqchen

@vinx13 (Member, Author) commented on Nov 15, 2023

@guoyaol it may be caused by unaligned vectorization, or it may be a false alarm when the arithmetic analyzer can't handle the index pattern.
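For reference, the failing ICHECK only requires the access-index coefficient to either be zero or to divide info.factor(); a toy illustration with hypothetical numbers:

# Hypothetical values: the buffer was rewritten to a 4-lane vector type
# (info.factor() == 4), but one access index is linear with coefficient 3.
factor, coeff = 4, 3
ok = (coeff == 0) or (factor % coeff == 0)
print(ok)  # False: this is the condition whose failure triggers the ICHECK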


Development

Successfully merging this pull request may close these issues: [TensorIR] Enhance PointerTypeRewrite (#15463)