Skip to content

Commit ae49faf

Browse files
committed
WIP: Implement function multi versioning in sysimg
1 parent 98c51e7 commit ae49faf

File tree

6 files changed

+314
-2
lines changed

6 files changed

+314
-2
lines changed

base/sysimg.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,24 @@ end
422422
INCLUDE_STATE = 3 # include = include_from_node1
423423
include("precompile.jl")
424424

425+
# Add up every element of `a`, starting from `zero(eltype(a))`.
# The loop is annotated with `@inbounds @simd` and the function is marked
# `@noinline`, so callers always reach it through a real call.
@noinline function test_clone_f(a)
    total = zero(eltype(a))
    @inbounds @simd for k = 1:length(a)
        total += a[k]
    end
    return total
end
432+
433+
# Call `test_clone_f(a)` once for each value in `1:n` and accumulate the
# results, starting from `zero(eltype(a))`. Marked `@noinline`.
@noinline function test_clone_g(a, n)
    acc = zero(eltype(a))
    for rep = 1:n
        acc += test_clone_f(a)
    end
    return acc
end
440+
441+
test_clone_g(Float64[], 1)
442+
425443
end # baremodule Base
426444

427445
using Base

src/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ endif
5454
LLVMLINK :=
5555

5656
ifeq ($(JULIACODEGEN),LLVM)
57-
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-gcroot llvm-lower-handlers cgmemmgr
57+
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-gcroot llvm-lower-handlers llvm-mv cgmemmgr
5858
FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir)
5959
LLVM_LIBS := all
6060
ifeq ($(USE_POLLY),1)

src/dump.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,26 @@ static void jl_load_sysimg_so(void)
244244
*sysimg_gvars[tls_offset_idx - 1] =
245245
(jl_value_t*)(uintptr_t)(jl_tls_offset == -1 ? 0 : jl_tls_offset);
246246
#endif
247+
#if defined(_CPU_X86_64_) || defined(_CPU_X86_)
248+
// WIP
249+
typedef void (*dispatch_t)(size_t, size_t*, void***, size_t**);
250+
dispatch_t dispatchf = (dispatch_t)jl_dlsym_e(jl_sysimg_handle,
251+
"jl_dispatch_sysimg_fvars");
252+
if (dispatchf) {
253+
size_t nfunc = 0;
254+
void **fptrs = NULL;
255+
size_t *fidxs = NULL;
256+
dispatchf(jl_test_cpu_feature(JL_X86_avx2) &&
257+
jl_test_cpu_feature(JL_X86_fma) &&
258+
jl_test_cpu_feature(JL_X86_popcnt), &nfunc, &fptrs, &fidxs);
259+
if (nfunc && fptrs && fidxs) {
260+
for (size_t i = 0; i < nfunc; i++) {
261+
size_t fi = fidxs[i];
262+
sysimg_fvars[fi] = fptrs[i];
263+
}
264+
}
265+
}
266+
#endif
247267

248268
#ifdef _OS_WINDOWS_
249269
sysimage_base = (intptr_t)jl_sysimg_handle;

src/jitlayers.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ void addOptimizationPasses(PassManager *PM)
177177
// Let the InstCombine pass remove the unnecessary load of
178178
// safepoint address first
179179
PM->add(createLowerPTLSPass(imaging_mode));
180+
PM->add(createJuliaMVPass());
180181
PM->add(createSROAPass()); // Break up aggregate allocas
181182
#ifndef INSTCOMBINE_BUG
182183
PM->add(createInstructionCombiningPass()); // Cleanup for scalarrepl.
@@ -1094,7 +1095,7 @@ static void jl_gen_llvm_globaldata(llvm::Module *mod, ValueToValueMapTy &VMap,
10941095
ArrayType *fvars_type = ArrayType::get(T_pvoidfunc, jl_sysimg_fvars.size());
10951096
addComdat(new GlobalVariable(*mod,
10961097
fvars_type,
1097-
true,
1098+
false,
10981099
GlobalVariable::ExternalLinkage,
10991100
MapValue(ConstantArray::get(fvars_type, ArrayRef<Constant*>(jl_sysimg_fvars)), VMap),
11001101
"jl_sysimg_fvars"));

src/jitlayers.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ JL_DLLEXPORT extern LLVMContext &jl_LLVMContext;
249249
Pass *createLowerPTLSPass(bool imaging_mode);
250250
Pass *createLowerGCFramePass();
251251
Pass *createLowerExcHandlersPass();
252+
Pass *createJuliaMVPass();
252253
// Whether the Function is an llvm or julia intrinsic.
253254
static inline bool isIntrinsicFunction(Function *F)
254255
{

src/llvm-mv.cpp

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
// This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
// Function multi-versioning
4+
#define DEBUG_TYPE "julia_mv"
5+
#undef DEBUG
6+
7+
// LLVM pass to clone function for different archs
8+
9+
#include "llvm-version.h"
10+
#include "support/dtypes.h"
11+
12+
#include <llvm/Pass.h>
13+
#include <llvm/IR/Module.h>
14+
#include <llvm/IR/Function.h>
15+
#include <llvm/IR/Instructions.h>
16+
#include <llvm/IR/Constants.h>
17+
#include <llvm/IR/LLVMContext.h>
18+
#include <llvm/Analysis/LoopInfo.h>
19+
#if JL_LLVM_VERSION >= 30700
20+
#include <llvm/IR/LegacyPassManager.h>
21+
#else
22+
#include <llvm/PassManager.h>
23+
#endif
24+
#include <llvm/IR/MDBuilder.h>
25+
#include <llvm/IR/IRBuilder.h>
26+
#include <llvm/Transforms/Utils/Cloning.h>
27+
#include "fix_llvm_assert.h"
28+
29+
#include "julia.h"
30+
#include "julia_internal.h"
31+
32+
#include <unordered_map>
33+
#include <vector>
34+
35+
using namespace llvm;
36+
37+
extern std::pair<MDNode*,MDNode*> tbaa_make_child(const char *name, MDNode *parent=nullptr, bool isConstant=false);
38+
extern "C" void jl_dump_llvm_value(void *v);
39+
40+
namespace {
41+
42+
struct JuliaMV: public ModulePass {
43+
static char ID;
44+
JuliaMV()
45+
: ModulePass(ID)
46+
{}
47+
48+
private:
49+
bool runOnModule(Module &M) override;
50+
void getAnalysisUsage(AnalysisUsage &AU) const override
51+
{
52+
AU.addRequired<LoopInfoWrapperPass>();
53+
AU.setPreservesAll();
54+
}
55+
bool shouldClone(Function &F);
56+
bool checkUses(Function &F, Constant *fary);
57+
bool checkUses(Function &F, Constant *V, Constant *fary, bool &inFVars);
58+
bool checkConstantUse(Function &F, Constant *V, Constant *fary, bool &inFVars);
59+
};
60+
61+
// Decide whether `F` is a candidate for cloning.
// Returns false for functions without a body. Otherwise returns true when
// LoopInfo reports at least one loop, or when any call in the body targets a
// function whose name starts with "llvm.muladd." or "llvm.fma.".
bool JuliaMV::shouldClone(Function &F)
{
    // A declaration has nothing to clone.
    if (F.empty())
        return false;
    auto &loops = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
    if (!loops.empty())
        return true;
    for (auto &block: F) {
        for (auto &inst: block) {
            auto *call = dyn_cast<CallInst>(&inst);
            if (!call)
                continue;
            // Indirect calls have no statically known callee.
            Function *target = call->getCalledFunction();
            if (!target)
                continue;
            StringRef fname = target->getName();
            if (fname.startswith("llvm.muladd.") || fname.startswith("llvm.fma."))
                return true;
        }
    }
    return false;
}
82+
83+
// Top-level use check for `F`: succeeds only if the recursive walk over the
// users of `F` accepts every one of them AND that walk reached `fary` at
// least once.
bool JuliaMV::checkUses(Function &F, Constant *fary)
{
    bool reachedFVars = false;
    if (!checkUses(F, &F, fary, reachedFVars))
        return false;
    return reachedFVars;
}
89+
90+
// Classify a single constant user `V` found while walking the uses of `F`.
// - `V` is `fary` itself: record that in `inFVars` and accept.
// - `V` is a bitcast constant expression: accept iff all of its own users
//   are acceptable (checked recursively).
// - anything else: reject.
bool JuliaMV::checkConstantUse(Function &F, Constant *V, Constant *fary, bool &inFVars)
{
    if (V == fary) {
        inFVars = true;
        return true;
    }
    auto *expr = dyn_cast<ConstantExpr>(V);
    const bool isBitCast = expr && expr->getOpcode() == Instruction::BitCast;
    if (!isBitCast)
        return false;
    return checkUses(F, V, fary, inFVars);
}
103+
104+
// Walk every user of `V`. Instruction users are always fine; every other user
// must be a Constant that checkConstantUse() accepts. Stops and returns false
// at the first user that is neither. `inFVars` is set by checkConstantUse()
// when `fary` is encountered.
bool JuliaMV::checkUses(Function &F, Constant *V, Constant *fary, bool &inFVars)
{
    for (User *U: V->users()) {
        if (isa<Instruction>(U))
            continue;
        if (auto *constUser = dyn_cast<Constant>(U)) {
            if (checkConstantUse(F, constUser, fary, inFVars))
                continue;
        }
        return false;
    }
    return true;
}
116+
117+
// Peel any number of bitcast constant expressions off `v` and return the
// Function underneath, or nullptr if something other than a Function (or a
// bitcast chain ending in one) is found.
static Function *getFunction(Value *v)
{
    for (Value *cur = v;;) {
        if (auto *fn = dyn_cast<Function>(cur))
            return fn;
        auto *expr = dyn_cast<ConstantExpr>(cur);
        if (!expr || expr->getOpcode() != Instruction::BitCast)
            return nullptr;
        // Look through the cast at the value being cast.
        cur = expr->getOperand(0);
    }
}
128+
129+
// Set the "target-features" attribute of `F` so that the AVX2/FMA/POPCNT
// feature set (plus the SSE levels listed below) is enabled, keeping whatever
// feature string `F` already carried.
//
// The pre-existing features are placed FIRST and the forced ones LAST: LLVM
// applies the comma-separated entries in order, so a later entry takes
// precedence. With the opposite order an inherited "-avx2" or "-fma" would
// silently cancel the features this clone exists for.
static void addFeatures(Function *F)
{
    static const char forced[] =
        "+avx2,+avx,+fma,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3";
    auto attr = F->getFnAttribute("target-features");
    std::string feature;
    if (attr.isStringAttribute())
        feature += attr.getValueAsString();
    // Avoid a leading comma when there was no (or an empty) previous value.
    if (!feature.empty())
        feature += ",";
    feature += forced;
    F->addFnAttr("target-features", feature);
}
140+
141+
// Entry point of the pass.
//
// Does nothing (and reports the module as unmodified) unless `M` defines
// `jl_sysimg_fvars` with a ConstantArray initializer. Otherwise it:
//   1. creates an empty `<name>.avx2` function for every function accepted by
//      shouldClone() and checkUses();
//   2. records the index of each such function inside `jl_sysimg_fvars`;
//   3. clones the bodies into the `.avx2` functions and adds the extra
//      target features to the clones;
//   4. in functions that were NOT cloned, replaces direct references to a
//      cloned function with a load from its `jl_sysimg_fvars` slot;
//   5. emits two internal constant tables (clone pointers and slot indices);
//   6. emits `jl_dispatch_sysimg_fvars(hasavx2, &n, &ptrs, &idxs)`, which
//      always stores the table length and stores either the two tables
//      (hasavx2 == 1) or null pointers.
//
// Returns true iff the module was changed.
bool JuliaMV::runOnModule(Module &M)
{
    GlobalVariable *fvars = M.getGlobalVariable("jl_sysimg_fvars");
    // This makes sure this only runs during sysimg generation
    if (!fvars || !fvars->hasInitializer())
        return false; // nothing was modified
    auto *fary = dyn_cast<ConstantArray>(fvars->getInitializer());
    if (!fary)
        return false; // nothing was modified
    // Built only after the guards so unrelated modules are left untouched.
    MDNode *tbaa_const = tbaa_make_child("jtbaa_const", nullptr, true).first;
    LLVMContext &ctx = M.getContext();

    // Step 1: pick the candidates first, then create the `.avx2` functions,
    // so the module's function list is not extended while it is being walked.
    std::vector<Function*> candidates;
    for (auto &F: M) {
        if (shouldClone(F) && checkUses(F, fary))
            candidates.push_back(&F);
    }
    ValueToValueMapTy VMap;
    for (Function *F: candidates) {
        Function *NF = Function::Create(cast<FunctionType>(F->getValueType()),
                                        F->getLinkage(), F->getName() + ".avx2", &M);
        NF->copyAttributesFrom(F);
        VMap[F] = NF;
    }

    // Step 2: map each cloned function to its slot index in jl_sysimg_fvars.
    std::unordered_map<Function*,size_t> idx_map;
    size_t nf = fary->getNumOperands();
    for (size_t i = 0; i < nf; i++) {
        if (Function *ele = getFunction(fary->getOperand(i))) {
            auto it = VMap.find(ele);
            if (it != VMap.end()) {
                idx_map[ele] = i;
            }
        }
    }

    // Step 3: fill in the clone bodies. Because the function -> clone mapping
    // is already in VMap, references between cloned functions are remapped to
    // the corresponding clones.
    for (const auto &entry: idx_map) {
        Function *oldF = entry.first;
        auto newF = cast<Function>(VMap[oldF]);
        Function::arg_iterator DestI = newF->arg_begin();
        for (Function::const_arg_iterator J = oldF->arg_begin(); J != oldF->arg_end(); ++J) {
            DestI->setName(J->getName());
            VMap[&*J] = &*DestI++;
        }
        SmallVector<ReturnInst*,8> Returns;
        CloneFunctionInto(newF, oldF, VMap, false, Returns);
        addFeatures(newF);
    }

    std::vector<Constant*> ptrs;
    std::vector<Constant*> idxs;
    auto T_void = Type::getVoidTy(ctx);
    auto T_pvoidfunc = FunctionType::get(T_void, false)->getPointerTo();
    auto T_size = (sizeof(size_t) == 8 ? Type::getInt64Ty(ctx) : Type::getInt32Ty(ctx));

    // Step 4: collect the table entries and redirect outside references.
    for (const auto &entry: idx_map) {
        Function *oldF = entry.first;
        size_t idx = entry.second;
        auto newF = cast<Function>(VMap[oldF]);
        ptrs.push_back(ConstantExpr::getBitCast(newF, T_pvoidfunc));
        auto offset = ConstantInt::get(T_size, idx);
        idxs.push_back(offset);
        // Snapshot the users: replaceUsesOfWith() below edits oldF's use
        // list, which must not happen while that list is being iterated.
        SmallVector<User*,8> users(oldF->user_begin(), oldF->user_end());
        for (User *user: users) {
            auto inst = dyn_cast<Instruction>(user);
            if (!inst)
                continue;
            // Functions that were cloned keep calling the original directly.
            auto encloseF = inst->getParent()->getParent();
            if (VMap.find(encloseF) != VMap.end())
                continue;
            // An instruction with several oldF operands is listed once per
            // operand but is fully rewritten on the first visit.
            bool stillUses = false;
            for (Use &op: inst->operands()) {
                if (op.get() == oldF) {
                    stillUses = true;
                    break;
                }
            }
            if (!stillUses)
                continue;
            // NOTE(review): the load is inserted right before `inst`; if
            // `inst` can be a PHI node this placement needs revisiting.
            Value *slot = ConstantExpr::getBitCast(fvars, T_pvoidfunc->getPointerTo());
            slot = GetElementPtrInst::Create(T_pvoidfunc, slot, {offset}, "", inst);
            Instruction *ptr = new LoadInst(slot, "", inst);
            ptr->setMetadata(llvm::LLVMContext::MD_tbaa, tbaa_const);
            ptr = new BitCastInst(ptr, oldF->getType(), "", inst);
            inst->replaceUsesOfWith(oldF, ptr);
        }
    }

    // Step 5: constant tables consumed by the dispatcher.
    ArrayType *fvars_type = ArrayType::get(T_pvoidfunc, ptrs.size());
    auto ptr_gv = new GlobalVariable(M, fvars_type, true, GlobalVariable::InternalLinkage,
                                     ConstantArray::get(fvars_type, ptrs));
    ArrayType *idxs_type = ArrayType::get(T_size, idxs.size());
    auto idx_gv = new GlobalVariable(M, idxs_type, true, GlobalVariable::InternalLinkage,
                                     ConstantArray::get(idxs_type, idxs));

    // Step 6: emit the dispatcher.
    // TODO
    std::vector<Type*> dispatch_args(0);
    dispatch_args.push_back(T_size); // hasavx2
    dispatch_args.push_back(T_size->getPointerTo());
    dispatch_args.push_back(fvars_type->getPointerTo()->getPointerTo());
    dispatch_args.push_back(idxs_type->getPointerTo()->getPointerTo());
    Function *dispatchF = Function::Create(FunctionType::get(T_void, dispatch_args, false),
                                           Function::ExternalLinkage,
                                           "jl_dispatch_sysimg_fvars", &M);
    IRBuilder<> builder(ctx);
    BasicBlock *b0 = BasicBlock::Create(ctx, "top", dispatchF);
    builder.SetInsertPoint(b0);
    DebugLoc noDbg;
    builder.SetCurrentDebugLocation(noDbg);

    std::vector<Argument*> args;
    for (auto &arg: dispatchF->args())
        args.push_back(&arg);

    auto sz_arg = args[1];
    auto fvars_arg = args[2];
    auto idxs_arg = args[3];

    // The table length is reported on both paths.
    builder.CreateStore(ConstantInt::get(T_size, ptrs.size()), sz_arg);

    BasicBlock *match_bb = BasicBlock::Create(ctx, "match");
    BasicBlock *fail_bb = BasicBlock::Create(ctx, "fail");
    builder.CreateCondBr(builder.CreateICmpEQ(args[0], ConstantInt::get(T_size, 1)),
                         match_bb, fail_bb);

    // hasavx2 == 1: hand out the clone tables.
    dispatchF->getBasicBlockList().push_back(match_bb);
    builder.SetInsertPoint(match_bb);
    builder.CreateStore(ptr_gv, fvars_arg);
    builder.CreateStore(idx_gv, idxs_arg);
    builder.CreateRetVoid();

    // Otherwise: report "no tables" with null pointers.
    dispatchF->getBasicBlockList().push_back(fail_bb);
    builder.SetInsertPoint(fail_bb);
    builder.CreateStore(ConstantPointerNull::get(fvars_type->getPointerTo()), fvars_arg);
    builder.CreateStore(ConstantPointerNull::get(idxs_type->getPointerTo()), idxs_arg);
    builder.CreateRetVoid();

    return true;
}
261+
262+
char JuliaMV::ID = 0;
263+
static RegisterPass<JuliaMV> X("JuliaMV", "JuliaMV Pass",
264+
false /* Only looks at CFG */,
265+
false /* Analysis Pass */);
266+
267+
}
268+
269+
// Factory for the multi-versioning pass: returns a newly heap-allocated
// JuliaMV instance as a generic Pass pointer.
Pass *createJuliaMVPass()
{
    Pass *mv = new JuliaMV();
    return mv;
}

0 commit comments

Comments
 (0)