From 0a3c7365915908837a41f99fc45585fdf06d9a65 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sun, 21 Jan 2024 22:53:26 +0800 Subject: [PATCH] [Unity][FIX] fix thread dtype mismatch (#16443) --- python/tvm/topi/cuda/scatter_elements.py | 2 +- src/te/operation/op_utils.cc | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 2f345b9d67ec..27567ea23e21 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -168,7 +168,7 @@ def gen_ir(data, indices, updates, out, axis, reduce_func): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) # Copy initial input data to output with ib.new_scope(): - num_blocks = ceil_div(full_range, max_threads) + num_blocks = cast(ceil_div(full_range, max_threads), "int32") bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") ib.scope_attr(bx, "thread_extent", num_blocks) diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index 1f386bc2dd67..7168933a320c 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -155,22 +155,24 @@ std::vector> MakeLoopNest(const Stage& stage, ICHECK(is_zero(dom->min)); ICHECK(is_positive_const(dom->extent)); // annotate the extent of the IterVar - nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, + cast(bind_iv->var.dtype(), dom->extent), no_op)); value_map[iv] = promote_to_iv_dtype(var); } else if (bind_iv->thread_tag == "pipeline") { // pipeline marker. ICHECK(is_zero(dom->min)); ICHECK(is_one(dom->extent)); // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, + cast(bind_iv->var.dtype(), dom->extent), no_op)); value_map[iv] = dom->min; } else { // Always restrict threaded IterVar to starts from 0. ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at " << dom->min; // annotate the extent of the IterVar - nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, + cast(bind_iv->var.dtype(), dom->extent), no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else if (stage->scope == "") {