From 9ff063953838388c7d8884c8c0a737dcccd1a64c Mon Sep 17 00:00:00 2001 From: Wang Yucheng Date: Thu, 27 May 2021 04:00:08 +0800 Subject: [PATCH] [Codegen][CUDA] Fix make_int4x cuda codegen vectorize (#8137) Co-authored-by: wangyucheng --- src/target/source/codegen_c.cc | 15 ++++--- src/target/source/codegen_cuda.cc | 44 ++++++++++++++++--- .../unittest/test_target_codegen_cuda.py | 9 +++- 3 files changed, 55 insertions(+), 13 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index d4d0e54c6db4..99c9452975d4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -212,13 +212,18 @@ std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr i PrintType(t.element_of(), os); os << "*)"; } - os << vid << " + ("; - PrintExpr(index, os); - os << ")"; if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { - os << " / " << (32 / t.bits()); + os << vid << ") + ("; + PrintExpr(index, os); + os << ")"; + os << " / " << t.lanes(); + os << ")[0]"; + } else { + os << vid << " + ("; + PrintExpr(index, os); + os << ")"; + os << "))[0]"; } - os << "))[0]"; } return os.str(); } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 4cc999bf9136..6e76c3538e71 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -809,18 +809,48 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4 && op->lanes == 8) { - // make_int4x8 + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { + bool fail = false; const int64_t* p = as_const_int(op->value); ICHECK(p); int64_t v = *p & 0xF; - v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; - if (op->dtype.is_uint()) { - os << "(uint)" << v; + + if (op->lanes == 4) { + v = (v << 12) | (v << 8) | (v << 4) | v; + if (op->dtype.is_uint()) { + os << "(uint16_t)" << v; + } else { + os << "(int16_t)" << v; + } } else { - os << "(int)" << v; + v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; + if (op->lanes == 8) { + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else if (op->lanes == 16 || op->lanes == 32) { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < op->lanes / 8; ++i) { + if (i != 0) os << ", "; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } + os << ')'; + } else { + fail = true; + } + } + + if (!fail) { + return; } - return; } std::string v = PrintExpr(op->value); diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index fc138bb43f1a..56ba9a085ffc 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -215,14 +215,21 @@ def check_cuda(n, value, lanes): y, x = s[A].op.axis s[A].vectorize(x) s[A].bind(y, bx) - fun = tvm.build(s, [A], "cuda", name="make_int4x8") + kernel_name = "make_int4x" + str(lanes) + fun = tvm.build(s, [A], "cuda", name=kernel_name) np_a = np.full((n, lanes), value, dtype="int8") a = tvm.nd.empty((n, lanes), dtype, dev) fun(a) np.testing.assert_equal(a.numpy(), np_a) + check_cuda(64, 1, 4) + check_cuda(64, 7, 4) check_cuda(64, 1, 8) check_cuda(64, 7, 8) + check_cuda(64, 1, 16) + check_cuda(64, 7, 16) + check_cuda(64, 1, 32) + check_cuda(64, 7, 32) @tvm.testing.requires_gpu