Commit

[Codegen][CUDA] Fix make_int4x cuda codegen vectorize (apache#8137)
Co-authored-by: wangyucheng <wangyucheng@sensetime.com>
2 people authored and Trevor Morris committed Jun 17, 2021
1 parent 371c2f8 commit 8117f3f
Showing 3 changed files with 55 additions and 13 deletions.
src/target/source/codegen_c.cc (15 changes: 10 additions & 5 deletions)
@@ -212,13 +212,18 @@ std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr i
       PrintType(t.element_of(), os);
       os << "*)";
     }
-    os << vid << " + (";
-    PrintExpr(index, os);
-    os << ")";
     if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
-      os << " / " << (32 / t.bits());
+      os << vid << ") + (";
+      PrintExpr(index, os);
+      os << ")";
+      os << " / " << t.lanes();
+      os << ")[0]";
+    } else {
+      os << vid << " + (";
+      PrintExpr(index, os);
+      os << ")";
+      os << "))[0]";
     }
-    os << "))[0]";
   }
   return os.str();
 }
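The branch above changes the access string GetBufferRef builds for 4-bit (and 1-bit int) vectorized loads and stores: the index is now divided by t.lanes() rather than the fixed 32 / t.bits(), and the offset is applied to the already-cast vector pointer. As a minimal sketch (not part of the PR), assuming a 16-lane int4 buffer that the CUDA backend views as int2, with invented names buf and idx:

// Illustrative only: buf, idx, and the kernel are made up, and the int2
// view of a 16-lane int4 vector is an assumption about the CUDA backend.
__global__ void buffer_ref_sketch(signed char* buf, int idx, int2* out) {
  // Old string: the divisor was 32 / t.bits() == 8 regardless of lane count,
  // and the offset was added inside the vector-pointer cast.
  out[0] = ((int2*)(buf + (idx) / 8))[0];
  // New string: offset the cast pointer and divide by t.lanes() == 16,
  // i.e. step one whole packed vector per 16-lane index increment.
  out[1] = ((int2*)(buf) + (idx) / 16)[0];
}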
src/target/source/codegen_cuda.cc (44 changes: 37 additions & 7 deletions)
@@ -809,18 +809,48 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
     return;
   }
 
-  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4 && op->lanes == 8) {
-    // make_int4x8
+  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
+    bool fail = false;
     const int64_t* p = as_const_int(op->value);
     ICHECK(p);
     int64_t v = *p & 0xF;
-    v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
-    if (op->dtype.is_uint()) {
-      os << "(uint)" << v;
+
+    if (op->lanes == 4) {
+      v = (v << 12) | (v << 8) | (v << 4) | v;
+      if (op->dtype.is_uint()) {
+        os << "(uint16_t)" << v;
+      } else {
+        os << "(int16_t)" << v;
+      }
     } else {
-      os << "(int)" << v;
+      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
+      if (op->lanes == 8) {
+        if (op->dtype.is_uint()) {
+          os << "(uint)" << v;
+        } else {
+          os << "(int)" << v;
+        }
+      } else if (op->lanes == 16 || op->lanes == 32) {
+        os << "make_";
+        PrintType(op->dtype, os);
+        os << '(';
+        for (int i = 0; i < op->lanes / 8; ++i) {
+          if (i != 0) os << ", ";
+          if (op->dtype.is_uint()) {
+            os << "(uint)" << v;
+          } else {
+            os << "(int)" << v;
+          }
+        }
+        os << ')';
+      } else {
+        fail = true;
+      }
     }
-    return;
+
+    if (!fail) {
+      return;
+    }
   }
 
   std::string v = PrintExpr(op->value);
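The broadcast path above packs the masked 4-bit constant into a 16-bit word for 4 lanes or a 32-bit word for 8 or more lanes, and for 16 or 32 lanes repeats that word through a make_* vector constructor (make_int2 / make_int4 by my reading of the PrintType call; the diff does not spell the names out). A standalone sketch of the packing for the value 7 used in the updated test, runnable as plain host code:

// Reproduces the nibble packing of the new branch for value = 7.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t v = 7 & 0xF;  // broadcast value masked to one nibble
  // lanes == 4: four nibbles in 16 bits -> emitted as "(int16_t)30583" (0x7777)
  int64_t v4 = (v << 12) | (v << 8) | (v << 4) | v;
  // lanes >= 8: eight nibbles in 32 bits -> emitted as "(int)2004318071" (0x77777777)
  int64_t v8 = (v << 28) | (v << 24) | (v << 20) | (v << 16) |
               (v << 12) | (v << 8) | (v << 4) | v;
  std::printf("int4x4  -> (int16_t)%lld\n", (long long)v4);
  std::printf("int4x8  -> (int)%lld\n", (long long)v8);
  // 16 and 32 lanes reuse the same 32-bit word once per 8 lanes, e.g.
  // make_int2((int)2004318071, (int)2004318071) for an int4x16 broadcast.
  std::printf("int4x16 -> make_int2((int)%lld, (int)%lld)\n",
              (long long)v8, (long long)v8);
  return 0;
}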
tests/python/unittest/test_target_codegen_cuda.py (9 changes: 8 additions & 1 deletion)
@@ -215,14 +215,21 @@ def check_cuda(n, value, lanes):
         y, x = s[A].op.axis
         s[A].vectorize(x)
         s[A].bind(y, bx)
-        fun = tvm.build(s, [A], "cuda", name="make_int4x8")
+        kernel_name = "make_int4x" + str(lanes)
+        fun = tvm.build(s, [A], "cuda", name=kernel_name)
         np_a = np.full((n, lanes), value, dtype="int8")
         a = tvm.nd.empty((n, lanes), dtype, dev)
         fun(a)
         np.testing.assert_equal(a.numpy(), np_a)
 
+    check_cuda(64, 1, 4)
+    check_cuda(64, 7, 4)
     check_cuda(64, 1, 8)
     check_cuda(64, 7, 8)
+    check_cuda(64, 1, 16)
+    check_cuda(64, 7, 16)
+    check_cuda(64, 1, 32)
+    check_cuda(64, 7, 32)
 
 
 @tvm.testing.requires_gpu
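The test now builds one kernel per lane count ("make_int4x4" through "make_int4x32"), fills a (64, lanes) buffer with the broadcast value, and compares the result against an int8 NumPy reference. For lanes == 16 and value 7, the generated code should amount to something like the hand-written sketch below; the kernel signature, launch shape, and int2 view are illustrative assumptions, not the actual TVM output:

// Rough approximation of the expected make_int4x16 kernel: each block,
// bound to one row of the (64, 16) buffer, stores one packed int4x16
// vector, i.e. an int2 whose two words hold sixteen 0x7 nibbles.
__global__ void make_int4x16_sketch(int2* A) {
  A[blockIdx.x] = make_int2((int)2004318071, (int)2004318071);
}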
