Skip to content

Commit 50798e4

Browse files
committed
Fix vcvtph2ps codegen
1 parent 4ac64fc commit 50798e4

File tree

3 files changed

+167
-16
lines changed

3 files changed

+167
-16
lines changed

src/codegen/llvm/codegen_llvm.cc

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55
#ifdef TVM_LLVM_VERSION
66
// Part of the code are adapted from Halide's CodeGen_LLVM
7+
#include <algorithm>
78

89
#include <tvm/runtime/device_api.h>
910
#include <tvm/runtime/c_runtime_api.h>
@@ -410,12 +411,16 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
410411
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
411412
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
412413
if (extent == num_elems && begin == 0) return vec;
413-
CHECK_LE(begin + extent, num_elems);
414-
std::vector<unsigned> indices;
414+
std::vector<llvm::Constant*> indices;
415+
indices.reserve(extent);
415416
for (int i = 0; i < extent; ++i) {
416-
indices.push_back(begin + i);
417+
if (begin + i >= 0 && begin + i < num_elems) {
418+
indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
419+
} else {
420+
indices.push_back(llvm::UndefValue::get(t_int32_));
421+
}
417422
}
418-
return builder_->CreateShuffleVector(vec, vec, indices);
423+
return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
419424
}
420425

421426
llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
@@ -446,24 +451,31 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
446451
v->getType()->getVectorNumElements());
447452
}
448453
while (vecs.size() > 1) {
449-
for (size_t i = 0; i < vecs.size(); i+=2) {
450-
if (i + 1 >= vecs.size()) {
451-
vecs[i / 2] = vecs[i]; continue;
452-
}
454+
std::vector<llvm::Value*> new_vecs;
455+
for (size_t i = 0; i < vecs.size() - 1; i += 2) {
453456
llvm::Value* lhs = vecs[i];
454457
llvm::Value* rhs = vecs[i + 1];
455-
int lanes = static_cast<int>(std::max(
456-
lhs->getType()->getVectorNumElements(),
457-
rhs->getType()->getVectorNumElements()));
458-
lhs = CreateVecPad(lhs, lanes);
459-
rhs = CreateVecPad(lhs, lanes);
458+
const auto lhs_lanes = lhs->getType()->getVectorNumElements();
459+
const auto rhs_lanes = rhs->getType()->getVectorNumElements();
460+
if (lhs_lanes < rhs_lanes) {
461+
lhs = CreateVecPad(lhs, rhs_lanes);
462+
} else if (rhs_lanes < lhs_lanes) {
463+
rhs = CreateVecPad(rhs, lhs_lanes);
464+
}
465+
const auto shared_lanes = std::max(lhs_lanes, rhs_lanes);
460466
std::vector<unsigned> mask;
461-
for (int i = 0; i < lanes * 2; ++i) {
467+
for (int i = 0; i < lhs_lanes; ++i) {
462468
mask.push_back(i);
463469
}
464-
vecs[i / 2] = builder_->CreateShuffleVector(lhs, rhs, mask);
470+
for (int i = 0; i < rhs_lanes; ++i) {
471+
mask.push_back(shared_lanes + i);
472+
}
473+
new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
474+
}
475+
if (vecs.size() % 2 != 0) {
476+
new_vecs.push_back(vecs.back());
465477
}
466-
vecs.resize((vecs.size() + 1) / 2);
478+
vecs.swap(new_vecs);
467479
}
468480
return CreateVecSlice(vecs[0], 0, total_lanes);
469481
}

src/codegen/llvm/codegen_x86_64.cc

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*!
2+
* Copyright (c) 2019 by Contributors
3+
* \file codegen_x86_64.cc
4+
* \brief X86-64 specific code generator
5+
*/
6+
#ifdef TVM_LLVM_VERSION
7+
#include "codegen_cpu.h"
8+
9+
#include "llvm/MC/MCSubtargetInfo.h"
10+
11+
namespace tvm {
12+
namespace codegen {
13+
14+
namespace {
15+
bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) {
16+
const auto* MCInfo = tm.getMCSubtargetInfo();
17+
return MCInfo->checkFeatures(std::string("+") + feature);
18+
}
19+
} // namespace
20+
21+
class CodeGenX86_64 final : public CodeGenCPU {
22+
public:
23+
llvm::Value* VisitExpr_(const Cast* op) override;
24+
25+
private:
26+
llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
27+
const std::vector<llvm::Value*>& args);
28+
};
29+
30+
llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
31+
// LLVM does not automatically generate the correct instruction sequences for
32+
// half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
33+
// vcvtph2ps), so we explicitly generate them ourselves.
34+
const auto from = op->value.type();
35+
const auto to = op->type;
36+
if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
37+
CHECK_EQ(from.lanes(), to.lanes());
38+
CHECK_NOTNULL(target_machine_);
39+
40+
const auto has_f16c = TargetHasFeature(*target_machine_, "f16c");
41+
const auto has_avx512 = TargetHasFeature(*target_machine_, "avx512f");
42+
43+
if (from.lanes() >= 16 && has_avx512) {
44+
return CallVectorIntrin(
45+
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, LLVMType(Float(32, from.lanes())),
46+
{
47+
MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
48+
ir::Call::PureIntrinsic)),
49+
MakeValue(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), from.lanes())),
50+
/*mask=*/MakeValue(ir::IntImm::make(Int(16), -1)),
51+
/*rounding-mode=*/MakeValue(ir::IntImm::make(Int(32), 4)),
52+
});
53+
}
54+
55+
if (from.lanes() >= 8 && has_f16c) {
56+
return CallVectorIntrin(
57+
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(Float(32, from.lanes())),
58+
{MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
59+
ir::Call::PureIntrinsic))});
60+
}
61+
}
62+
63+
return CodeGenCPU::VisitExpr_(op);
64+
}
65+
66+
llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes,
67+
llvm::Type* result_ty,
68+
69+
const std::vector<llvm::Value*>& args) {
70+
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
71+
if (intrin_lanes == result_ty->getVectorNumElements()) {
72+
return builder_->CreateCall(f, args);
73+
}
74+
75+
// Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary),
76+
// compute each result, and then concatenate the vectors (slicing the result if necessary).
77+
CHECK_LT(intrin_lanes, result_ty->getVectorNumElements());
78+
std::vector<llvm::Value*> split_results;
79+
for (auto i = 0; i < result_ty->getVectorNumElements(); i += intrin_lanes) {
80+
std::vector<llvm::Value*> split_args;
81+
for (const auto& v : args) {
82+
if (v->getType()->isVectorTy()) {
83+
CHECK_EQ(v->getType()->getVectorNumElements(), result_ty->getVectorNumElements());
84+
split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
85+
} else {
86+
split_args.push_back(v);
87+
}
88+
}
89+
split_results.push_back(CallVectorIntrin(
90+
id, intrin_lanes, llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes),
91+
split_args));
92+
}
93+
return CreateVecSlice(CreateVecConcat(split_results), 0, result_ty->getVectorNumElements());
94+
}
95+
96+
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
97+
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
98+
CodeGenLLVM* cg = new CodeGenX86_64();
99+
*rv = static_cast<void*>(cg);
100+
});
101+
102+
} // namespace codegen
103+
} // namespace tvm
104+
#endif // TVM_LLVM_VERSION
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import tvm
2+
import re
3+
import os
4+
import ctypes
5+
6+
def test_fp16_to_fp32():
7+
def f(target, width, match=None, not_match=None):
8+
elements = 64
9+
n = tvm.convert(elements)
10+
A = tvm.placeholder((n, width), dtype="float16", name='A')
11+
B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
12+
s = tvm.create_schedule(B.op)
13+
s[B].vectorize(s[B].op.axis[1])
14+
f = tvm.build(s, [A, B], target)
15+
16+
assembly = f.get_source('asm').splitlines()
17+
if match:
18+
matches = [l for l in assembly if re.search(match, l)]
19+
assert matches
20+
if not_match:
21+
not_matches = [l for l in assembly if re.search(not_match, l)]
22+
assert not not_matches
23+
24+
25+
f(target='llvm -mcpu=skylake-avx512', width=15, match="vcvtph2ps.*ymm", not_match="vcvtph2ps.*zmm")
26+
f(target='llvm -mcpu=skylake-avx512', width=16, match="vcvtph2ps.*zmm")
27+
f(target='llvm -mcpu=skylake-avx512', width=17, match="vcvtph2ps.*zmm")
28+
f(target='llvm -mcpu=skylake-avx512', width=49, match="vcvtph2ps.*zmm")
29+
f(target='llvm -mcpu=core-avx2', width=8, match="vcvtph2ps.*ymm")
30+
f(target='llvm -mcpu=core-avx2', width=9, match="vcvtph2ps.*ymm")
31+
f(target='llvm', width=9, not_match="vcvtph2ps")
32+
33+
34+
if __name__ == "__main__":
35+
test_fp16_to_fp32()

0 commit comments

Comments
 (0)