Skip to content

Commit eb1ed11

Browse files
ajtulloch authored and tqchen committed
Fix vcvtph2ps codegen (#2925)
1 parent 4344d4a commit eb1ed11

File tree

3 files changed

+221
-17
lines changed

3 files changed

+221
-17
lines changed

src/codegen/llvm/codegen_llvm.cc

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
*/
55
#ifdef TVM_LLVM_VERSION
66
// Parts of the code are adapted from Halide's CodeGen_LLVM
7-
87
#include <tvm/runtime/device_api.h>
98
#include <tvm/runtime/c_runtime_api.h>
9+
10+
#include <algorithm>
11+
1012
#include "codegen_llvm.h"
1113
#include "codegen_cpu.h"
1214
#include "../../pass/ir_util.h"
@@ -410,12 +412,16 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
410412
// Diff view of CreateVecSlice: a gutter line ending in '-' marks the next line as removed by this
// commit, one ending in '+' marks it as added, and a plain number pair marks unchanged context.
// The function emits a shuffle of `vec` that selects `extent` lanes starting at lane `begin`.
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
411413
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
412414
// A slice that starts at lane 0 and spans every lane is the input itself; no shuffle is emitted.
if (extent == num_elems && begin == 0) return vec;
413-
CHECK_LE(begin + extent, num_elems);
414-
std::vector<unsigned> indices;
415+
std::vector<llvm::Constant*> indices;
416+
indices.reserve(extent);
415417
for (int i = 0; i < extent; ++i) {
416-
indices.push_back(begin + i);
// Added version: a lane inside [0, num_elems) gets a constant index of type t_int32_, any other
// lane gets an undef index. Together with the removal of the CHECK_LE above, this allows a
// slice that extends past the end of `vec`.
418+
if (begin + i >= 0 && begin + i < num_elems) {
419+
indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
420+
} else {
421+
indices.push_back(llvm::UndefValue::get(t_int32_));
422+
}
417423
}
418-
return builder_->CreateShuffleVector(vec, vec, indices);
// Added version: the mask is passed as a constant vector built from the per-lane constants.
424+
return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
419425
}
420426

421427
llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
@@ -446,24 +452,31 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
446452
v->getType()->getVectorNumElements());
447453
}
448454
while (vecs.size() > 1) {
449-
for (size_t i = 0; i < vecs.size(); i+=2) {
450-
if (i + 1 >= vecs.size()) {
451-
vecs[i / 2] = vecs[i]; continue;
452-
}
455+
std::vector<llvm::Value*> new_vecs;
456+
for (size_t i = 0; i < vecs.size() - 1; i += 2) {
453457
llvm::Value* lhs = vecs[i];
454458
llvm::Value* rhs = vecs[i + 1];
455-
int lanes = static_cast<int>(std::max(
456-
lhs->getType()->getVectorNumElements(),
457-
rhs->getType()->getVectorNumElements()));
458-
lhs = CreateVecPad(lhs, lanes);
459-
rhs = CreateVecPad(lhs, lanes);
459+
const size_t lhs_lanes = lhs->getType()->getVectorNumElements();
460+
const size_t rhs_lanes = rhs->getType()->getVectorNumElements();
461+
if (lhs_lanes < rhs_lanes) {
462+
lhs = CreateVecPad(lhs, rhs_lanes);
463+
} else if (rhs_lanes < lhs_lanes) {
464+
rhs = CreateVecPad(rhs, lhs_lanes);
465+
}
466+
const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
460467
std::vector<unsigned> mask;
461-
for (int i = 0; i < lanes * 2; ++i) {
468+
for (size_t i = 0; i < lhs_lanes; ++i) {
462469
mask.push_back(i);
463470
}
464-
vecs[i / 2] = builder_->CreateShuffleVector(lhs, rhs, mask);
471+
for (size_t i = 0; i < rhs_lanes; ++i) {
472+
mask.push_back(shared_lanes + i);
473+
}
474+
new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
475+
}
476+
if (vecs.size() % 2 != 0) {
477+
new_vecs.push_back(vecs.back());
465478
}
466-
vecs.resize((vecs.size() + 1) / 2);
479+
vecs.swap(new_vecs);
467480
}
468481
return CreateVecSlice(vecs[0], 0, total_lanes);
469482
}

src/codegen/llvm/codegen_x86_64.cc

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*!
2+
* Copyright (c) 2019 by Contributors
3+
* \file codegen_x86_64.cc
4+
* \brief X86-64 specific code generator
5+
*/
6+
#ifdef TVM_LLVM_VERSION
7+
#include "codegen_cpu.h"
8+
9+
#include "llvm/MC/MCSubtargetInfo.h"
10+
11+
namespace tvm {
12+
namespace codegen {
13+
14+
namespace {
15+
// Reports whether `feature` (e.g. "f16c") is enabled for the target machine `tm`.
// With TVM_LLVM_VERSION >= 60 the answer comes from MCSubtargetInfo::checkFeatures("+<feature>");
// with older LLVM the function always returns false, because the manual reimplementation kept
// below is commented out.
bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) {
16+
// MCSubTargetInfo::checkFeatures was added in LLVM 6.0
17+
#if TVM_LLVM_VERSION >= 60
18+
const auto* MCInfo = tm.getMCSubtargetInfo();
19+
return MCInfo->checkFeatures(std::string("+") + feature);
20+
#else
21+
return false;
22+
// TODO(tulloch) - enable this block, need to figure out how to reimplement
23+
// this given visibility constraints, similar to
24+
// https://github.com/rust-lang/rust/pull/31709
25+
26+
// Copied from
27+
// https://github.com/llvm-mirror/llvm/blob/5136df4/lib/MC/MCSubtargetInfo.cpp#L78-L88.
28+
29+
// auto checkFeatures = [&](const std::string FS) {
30+
// llvm::SubtargetFeatures T(FS);
31+
// llvm::FeatureBitset Set, All;
32+
// for (std::string F : T.getFeatures()) {
33+
// llvm::SubtargetFeatures::ApplyFeatureFlag(Set, F, MCInfo->ProcFeatures);
34+
// if (F[0] == '-') {
35+
// F[0] = '+';
36+
// }
37+
// llvm::SubtargetFeatures::ApplyFeatureFlag(All, F, MCInfo->ProcFeatures);
38+
// }
39+
// return (MCInfo->getFeatureBits() & All) == Set;
40+
// };
41+
// return checkFeatures(MCInfo, std::string("+") + feature);
42+
#endif
43+
}
44+
} // namespace
45+
46+
// Code generator for x86-64 targets. It derives from CodeGenCPU, overrides the lowering of Cast
// expressions, and declares a private helper for calling a vector intrinsic.
class CodeGenX86_64 final : public CodeGenCPU {
47+
public:
48+
// Lowers a Cast expression (overrides the base-class visitor).
llvm::Value* VisitExpr_(const Cast* op) override;
49+
50+
private:
51+
// Calls the intrinsic `id` with `args`, producing a value of type `result_ty`.
// `intrin_lanes` is compared against the lane count of `result_ty` by the definition.
llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
52+
const std::vector<llvm::Value*>& args);
53+
};
54+
55+
// Lowers a Cast expression. A float16 -> float32 cast is emitted through an explicit x86
// conversion intrinsic when the target supports one and the vector is wide enough:
//   - at least 16 lanes and "avx512f": x86_avx512_mask_vcvtph2ps_512, 16 lanes per call;
//   - at least 8 lanes and "f16c":     x86_vcvtph2ps_256, 8 lanes per call.
// Every other cast is handed to CodeGenCPU::VisitExpr_.
llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
56+
// LLVM does not automatically generate the correct instruction sequences for
57+
// half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
58+
// vcvtph2ps), so we explicitly generate them ourselves.
59+
const auto from = op->value.type();
60+
const auto to = op->type;
61+
if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
62+
CHECK_EQ(from.lanes(), to.lanes());
63+
CHECK_NOTNULL(target_machine_);
64+
65+
const auto has_f16c = TargetHasFeature(*target_machine_, "f16c");
66+
const auto has_avx512 = TargetHasFeature(*target_machine_, "avx512f");
67+
// The AVX-512 path is tried first, so it takes precedence when both features are present.
68+
if (from.lanes() >= 16 && has_avx512) {
69+
return CallVectorIntrin(
70+
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, LLVMType(Float(32, from.lanes())),
71+
{
// Source operand: the float16 vector reinterpreted as int16 lanes.
72+
MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
73+
ir::Call::PureIntrinsic)),
// NOTE(review): all-zero float32 vector; presumably the pass-through value for lanes the
// mask disables -- confirm against the intrinsic's signature.
74+
MakeValue(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), from.lanes())),
// 16-bit mask with every bit set (-1).
75+
/*mask=*/MakeValue(ir::IntImm::make(Int(16), -1)),
// NOTE(review): 4 presumably corresponds to _MM_FROUND_CUR_DIRECTION -- confirm.
76+
/*rounding-mode=*/MakeValue(ir::IntImm::make(Int(32), 4)),
77+
});
78+
}
79+
80+
if (from.lanes() >= 8 && has_f16c) {
81+
return CallVectorIntrin(
82+
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(Float(32, from.lanes())),
// Single operand: the float16 vector reinterpreted as int16 lanes.
83+
{MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
84+
ir::Call::PureIntrinsic))});
85+
}
86+
}
87+
// No specialised path applied: use the base-class lowering.
88+
return CodeGenCPU::VisitExpr_(op);
89+
}
90+
91+
// Calls the intrinsic `id`, which operates on `intrin_lanes` lanes, to produce a value of type
// `result_ty`.
//   id           - LLVM intrinsic to call.
//   intrin_lanes - lane count handled by a single call of the intrinsic.
//   result_ty    - vector type of the overall result; must have at least intrin_lanes lanes.
//   args         - call arguments; vector arguments must have as many lanes as result_ty,
//                  non-vector arguments are forwarded unchanged to every call.
// When result_ty has exactly intrin_lanes lanes the intrinsic is called once. Otherwise the
// vector arguments are cut into intrin_lanes-wide slices, the function recurses on each slice,
// and the partial results are concatenated and sliced back to the lane count of result_ty.
llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes,
92+
llvm::Type* result_ty,
93+
94+
const std::vector<llvm::Value*>& args) {
95+
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
96+
if (intrin_lanes == result_ty->getVectorNumElements()) {
97+
return builder_->CreateCall(f, args);
98+
}
99+
100+
// Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary),
101+
// compute each result, and then concatenate the vectors (slicing the result if necessary).
102+
CHECK_LT(intrin_lanes, result_ty->getVectorNumElements());
103+
std::vector<llvm::Value*> split_results;
// `i` is the first lane of the current slice; the last slice may extend past the final lane.
104+
for (size_t i = 0;
105+
i < static_cast<size_t>(result_ty->getVectorNumElements());
106+
i += intrin_lanes) {
107+
std::vector<llvm::Value*> split_args;
108+
for (const auto& v : args) {
109+
if (v->getType()->isVectorTy()) {
110+
CHECK_EQ(v->getType()->getVectorNumElements(), result_ty->getVectorNumElements());
111+
split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
112+
} else {
113+
split_args.push_back(v);
114+
}
115+
}
// The recursive call sees a result type of exactly intrin_lanes lanes, so it takes the
// direct-call branch above.
116+
split_results.push_back(CallVectorIntrin(
117+
id, intrin_lanes, llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes),
118+
split_args));
119+
}
// Slice the concatenation back to the lane count of result_ty.
120+
return CreateVecSlice(CreateVecConcat(split_results), 0, result_ty->getVectorNumElements());
121+
}
122+
123+
// Registers the global function "tvm.codegen.llvm.target_x86-64". Its body allocates a
// CodeGenX86_64 with `new` and stores it in the return value as a void*.
// NOTE(review): ownership of the raw pointer presumably passes to whoever reads the return
// value -- confirm against the caller.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
124+
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
125+
CodeGenLLVM* cg = new CodeGenX86_64();
126+
*rv = static_cast<void*>(cg);
127+
});
128+
129+
} // namespace codegen
130+
} // namespace tvm
131+
#endif // TVM_LLVM_VERSION
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import tvm
2+
import re
3+
4+
5+
# Checks which vcvtph2ps forms appear in the assembly generated for a vectorized
# float16 -> float32 cast, across several targets and vector widths.
def test_fp16_to_fp32():
6+
# The test is skipped (with a printed message) when the LLVM major version is below 6.
if tvm.codegen.llvm_version_major() < 6:
7+
print("Skipping due to LLVM version being {} < 6".format(
8+
tvm.codegen.llvm_version_major()))
9+
return
10+
11+
# Builds a (64, width) float16 -> float32 cast vectorized along axis 1 for `target`, then
# asserts that some assembly line matches the regex `match` (if given) and that no line
# matches the regex `not_match` (if given).
def fp16_to_fp32(target, width, match=None, not_match=None):
12+
elements = 64
13+
n = tvm.convert(elements)
14+
A = tvm.placeholder((n, width), dtype="float16", name='A')
15+
B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
16+
s = tvm.create_schedule(B.op)
17+
s[B].vectorize(s[B].op.axis[1])
18+
f = tvm.build(s, [A, B], target)
19+
20+
assembly = f.get_source('asm').splitlines()
21+
if match:
22+
matches = [l for l in assembly if re.search(match, l)]
23+
assert matches
24+
if not_match:
25+
not_matches = [l for l in assembly if re.search(not_match, l)]
26+
assert not not_matches
27+
28+
# skylake-avx512, width 15: expects the ymm form and no zmm form.
29+
fp16_to_fp32(
30+
'llvm -mcpu=skylake-avx512', 15,
31+
match="vcvtph2ps.*ymm", not_match="vcvtph2ps.*zmm")
# skylake-avx512, widths 16, 17 and 49: expects the zmm form.
32+
fp16_to_fp32(
33+
'llvm -mcpu=skylake-avx512', 16,
34+
match="vcvtph2ps.*zmm")
35+
fp16_to_fp32(
36+
'llvm -mcpu=skylake-avx512', 17,
37+
match="vcvtph2ps.*zmm")
38+
fp16_to_fp32(
39+
'llvm -mcpu=skylake-avx512', 49,
40+
match="vcvtph2ps.*zmm")
# avx512f disabled: expects the ymm form only.
41+
fp16_to_fp32(
42+
'llvm -mcpu=skylake-avx512 -mattr=-avx512f', 49,
43+
match="vcvtph2ps.*ymm",
44+
not_match="vcvtph2ps.*zmm")
# f16c and avx512f both disabled: expects no vcvtph2ps at all.
45+
fp16_to_fp32(
46+
'llvm -mcpu=skylake-avx512 -mattr=-f16c,-avx512f', 49,
47+
not_match="vcvtph2ps")
# core-avx2, widths 8 and 9: expects the ymm form.
48+
fp16_to_fp32(
49+
'llvm -mcpu=core-avx2', 8,
50+
match="vcvtph2ps.*ymm")
51+
fp16_to_fp32(
52+
'llvm -mcpu=core-avx2', 9,
53+
match="vcvtph2ps.*ymm")
# Plain 'llvm' target with no -mcpu: expects no vcvtph2ps.
54+
fp16_to_fp32(
55+
'llvm', 9,
56+
not_match="vcvtph2ps")
57+
58+
59+
# Run the test directly when this file is executed as a script.
if __name__ == "__main__":
60+
test_fp16_to_fp32()

0 commit comments

Comments
 (0)