Skip to content

Commit b7377a2

Browse files
committed
working memcpy on floats
1 parent 5fa4a24 commit b7377a2

File tree

7 files changed

+287
-79
lines changed

7 files changed

+287
-79
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.swp

bench/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
%-unopt.ll: %.c
66
#../build/bin/clang -fno-unroll-loops $^ -O1 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math | ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
77
#../build/bin/clang -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math | ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
8-
../build/bin/clang -fno-unroll-loops $^ -O0 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math -o $@ #| ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
8+
../build/bin/clang -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math -o $@ #| ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
99

1010
%-unopt.ll: %.cpp
1111
#../build/bin/clang++ -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I ../../adept-2.0.5/include -I../../tapenade/ADFirstAidKit

enzyme/Enzyme/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
if (${LLVM_VERSION_MAJOR} LESS 8)
66
add_llvm_loadable_module( LLVMEnzyme-${LLVM_VERSION_MAJOR}
77
Enzyme.cpp
8+
Utils.cpp
89
SCEV/ScalarEvolutionExpander.cpp
910
DEPENDS
1011
intrinsics_gen
@@ -14,6 +15,7 @@ if (${LLVM_VERSION_MAJOR} LESS 8)
1415
else()
1516
add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR}
1617
Enzyme.cpp
18+
Utils.cpp
1719
SCEV/ScalarEvolutionExpander.cpp
1820
MODULE
1921
DEPENDS

enzyme/Enzyme/Enzyme.cpp

Lines changed: 43 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <llvm/Config/llvm-config.h>
2020

21+
#include "Utils.h"
2122
#include "SCEV/ScalarEvolutionExpander.h"
2223

2324
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
@@ -141,12 +142,6 @@ std::string tostring(DIFFE_TYPE t) {
141142
}
142143
}
143144

144-
static inline FastMathFlags getFast() {
145-
FastMathFlags f;
146-
f.set();
147-
return f;
148-
}
149-
150145
Instruction *getNextNonDebugInstruction(Instruction* Z) {
151146
for (Instruction *I = Z->getNextNode(); I; I = I->getNextNode())
152147
if (!isa<DbgInfoIntrinsic>(I))
@@ -349,21 +344,22 @@ bool isIntASecretFloat(Value* val) {
349344
assert(0 && "unsure if constant or not");
350345
}
351346

352-
bool isIntPointerASecretFloat(Value* val) {
347+
//! return the secret float type if found, otherwise nullptr
348+
Type* isIntPointerASecretFloat(Value* val) {
353349
assert(val->getType()->isPointerTy());
354350
assert(cast<PointerType>(val->getType())->getElementType()->isIntegerTy());
355351

356-
if (isa<UndefValue>(val)) return true;
352+
if (isa<UndefValue>(val)) return nullptr;
357353

358354
if (auto cint = dyn_cast<ConstantInt>(val)) {
359-
if (!cint->isZero()) return false;
355+
if (!cint->isZero()) return nullptr;
360356
assert(0 && "unsure if constant or not because constantint");
361357
//if (cint->isOne()) return cint;
362358
}
363359

364360

365361
if (auto inst = dyn_cast<Instruction>(val)) {
366-
bool floatingUse = false;
362+
Type* floatingUse = nullptr;
367363
bool pointerUse = false;
368364
SmallPtrSet<Value*, 4> seen;
369365

@@ -374,7 +370,11 @@ bool isIntPointerASecretFloat(Value* val) {
374370
do {
375371
Type* let = cast<PointerType>(v->getType())->getElementType();
376372
if (let->isFloatingPointTy()) {
377-
floatingUse = true;
373+
if (floatingUse == nullptr) {
374+
floatingUse = let;
375+
} else {
376+
assert(floatingUse == let);
377+
}
378378
}
379379
if (auto ci = dyn_cast<CastInst>(v)) {
380380
if (auto cal = dyn_cast<CallInst>(ci->getOperand(0))) {
@@ -409,7 +409,11 @@ bool isIntPointerASecretFloat(Value* val) {
409409
llvm::errs() << " for val " << *v << *et << "\n";
410410

411411
if (et->isFloatingPointTy()) {
412-
floatingUse = true;
412+
if (floatingUse == nullptr) {
413+
floatingUse = et;
414+
} else {
415+
assert(floatingUse == et);
416+
}
413417
}
414418
if (et->isPointerTy()) {
415419
pointerUse = true;
@@ -431,8 +435,8 @@ bool isIntPointerASecretFloat(Value* val) {
431435
}
432436
}
433437

434-
if (pointerUse && !floatingUse) return false;
435-
if (!pointerUse && floatingUse) return true;
438+
if (pointerUse && (floatingUse == nullptr)) return nullptr;
439+
if (!pointerUse && (floatingUse != nullptr)) return floatingUse;
436440
llvm::errs() << *inst->getParent()->getParent() << "\n";
437441
llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << "\n";
438442
assert(0 && "ambiguous unsure if constant or not");
@@ -894,6 +898,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
894898
nullptr);
895899
NewF->setAttributes(F->getAttributes());
896900

901+
if (enzyme_preopt) {
897902
{
898903
FunctionAnalysisManager AM;
899904
AM.registerPass([] { return LoopAnalysis(); });
@@ -908,8 +913,6 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
908913

909914
}
910915

911-
if (enzyme_preopt) {
912-
913916
if(autodiff_inline) {
914917
llvm::errs() << "running inlining process\n";
915918
forceRecursiveInlining(NewF, F);
@@ -1098,52 +1101,32 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
10981101
LoopAnalysisManager LAM;
10991102
AM.registerPass([&] { return LoopAnalysisManagerFunctionProxy(LAM); });
11001103
LAM.registerPass([&] { return FunctionAnalysisManagerLoopProxy(AM); });
1101-
1102-
SimplifyCFGOptions scfgo(/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false, /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
1103-
SimplifyCFGPass(scfgo).run(*NewF, AM);
1104-
LoopSimplifyPass().run(*NewF, AM);
1105-
1106-
if (autodiff_inline) {
1107-
createFunctionToLoopPassAdaptor(LoopIdiomRecognizePass()).run(*NewF, AM);
1108-
}
1109-
DSEPass().run(*NewF, AM);
1110-
LoopSimplifyPass().run(*NewF, AM);
1111-
1112-
}
1113-
}
1114-
1115-
{
1116-
FunctionAnalysisManager AM;
1117-
AM.registerPass([] { return AAManager(); });
1118-
AM.registerPass([] { return ScalarEvolutionAnalysis(); });
1119-
AM.registerPass([] { return AssumptionAnalysis(); });
1120-
AM.registerPass([] { return TargetLibraryAnalysis(); });
1121-
AM.registerPass([] { return TargetIRAnalysis(); });
1122-
AM.registerPass([] { return LoopAnalysis(); });
1123-
AM.registerPass([] { return MemorySSAAnalysis(); });
1124-
AM.registerPass([] { return DominatorTreeAnalysis(); });
1125-
AM.registerPass([] { return MemoryDependenceAnalysis(); });
1126-
#if LLVM_VERSION_MAJOR > 6
1127-
AM.registerPass([] { return PhiValuesAnalysis(); });
1128-
#endif
1129-
#if LLVM_VERSION_MAJOR >= 8
1130-
AM.registerPass([] { return PassInstrumentationAnalysis(); });
1131-
#endif
11321104

11331105
ModuleAnalysisManager MAM;
11341106
AM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
11351107
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(AM); });
11361108

1109+
SimplifyCFGOptions scfgo(/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false, /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
1110+
SimplifyCFGPass(scfgo).run(*NewF, AM);
1111+
LoopSimplifyPass().run(*NewF, AM);
1112+
1113+
//AAManager().run(*NewF, AM)
11371114
BasicAA ba;
11381115
auto baa = new BasicAAResult(ba.run(*NewF, AM));
11391116
AA.addAAResult(*baa);
11401117

11411118
ScopedNoAliasAA sa;
11421119
auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM));
11431120
AA.addAAResult(*saa);
1144-
1121+
if (autodiff_inline) {
1122+
createFunctionToLoopPassAdaptor(LoopIdiomRecognizePass()).run(*NewF, AM);
11451123
}
1124+
DSEPass().run(*NewF, AM);
1125+
LoopSimplifyPass().run(*NewF, AM);
11461126

1127+
}
1128+
}
1129+
11471130
if (autodiff_print)
11481131
llvm::errs() << "after simplification :\n" << *NewF << "\n";
11491132

@@ -3983,11 +3966,10 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
39833966
tbuild.SetInsertPoint(&gutils->reverseBlocks[loopContext.exit]->back());
39843967
}
39853968

3969+
loopContext.antivar->addIncoming(gutils->lookupM(loopContext.limit, tbuild), gutils->reverseBlocks[loopContext.exit]);
39863970
auto sub = Builder2.CreateSub(loopContext.antivar, ConstantInt::get(loopContext.antivar->getType(), 1));
39873971
for(BasicBlock* in: successors(loopContext.latch) ) {
3988-
if (loopContext.exit == in) {
3989-
loopContext.antivar->addIncoming(gutils->lookupM(loopContext.limit, tbuild), gutils->reverseBlocks[in]);
3990-
} else if (gutils->LI.getLoopFor(in) == gutils->LI.getLoopFor(BB)) {
3972+
if (gutils->LI.getLoopFor(in) == gutils->LI.getLoopFor(BB)) {
39913973
loopContext.antivar->addIncoming(sub, gutils->reverseBlocks[in]);
39923974
}
39933975
}
@@ -4057,7 +4039,16 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40574039
switch(op->getIntrinsicID()) {
40584040
case Intrinsic::memcpy: {
40594041
if (gutils->isConstantInstruction(inst)) continue;
4060-
if (!isIntPointerASecretFloat(op->getOperand(0)) ) {
4042+
if (Type* secretty = isIntPointerASecretFloat(op->getOperand(0)) ) {
4043+
SmallVector<Value*, 4> args;
4044+
auto secretpt = PointerType::getUnqual(secretty);
4045+
4046+
args.push_back(Builder2.CreatePointerCast(invertPointer(op->getOperand(0)), secretpt));
4047+
args.push_back(Builder2.CreatePointerCast(invertPointer(op->getOperand(1)), secretpt));
4048+
args.push_back(lookup(op->getOperand(2)));
4049+
auto dmemcpy = getOrInsertDifferentialFloatMemcpy(*M, secretpt);
4050+
auto cal = Builder2.CreateCall(dmemcpy, args);
4051+
} else {
40614052
if (topLevel) {
40624053
SmallVector<Value*, 4> args;
40634054
IRBuilder <>BuilderZ(op);
@@ -4072,32 +4063,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40724063
cal->setCallingConv(op->getCallingConv());
40734064
cal->setTailCallKind(op->getTailCallKind());
40744065
}
4075-
} else {
4076-
//no change to forward pass if represents float
4077-
//Zero the destination
4078-
assert(0 && "TODO: memcpy has bug that needs fixing (per int double vs ptr)");
4079-
/*
4080-
{
4081-
TODO BECOME MEMSET
4082-
SmallVector<Value*, 4> args;
4083-
// source and dest are swapped
4084-
args.push_back(invertPointer(op->getOperand(1)));
4085-
args.push_back(invertPointer(op->getOperand(0)));
4086-
args.push_back(lookup(op->getOperand(2)));
4087-
args.push_back(lookup(op->getOperand(3)));
4088-
4089-
Type *tys[] = {args[0]->getType(), args[1]->getType(), args[2]->getType()};
4090-
auto cal = Builder2.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args);
4091-
cal->setAttributes(op->getAttributes());
4092-
cal->setCallingConv(op->getCallingConv());
4093-
cal->setTailCallKind(op->getTailCallKind());
4094-
}
4095-
4096-
4097-
{
4098-
4099-
}
4100-
*/
41014066
}
41024067
break;
41034068
}

enzyme/Enzyme/Utils.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Utils.cpp
3+
*
4+
* Copyright (C) 2019 William S. Moses (enzyme@wsmoses.com) - All Rights Reserved
5+
*
6+
* For commercial use of this code please contact the author(s) above.
7+
*
8+
* For research use of the code please use the following citation.
9+
*
10+
* \misc{mosesenzyme,
11+
author = {William S. Moses, Tim Kaler},
12+
title = {Enzyme: LLVM Automatic Differentiation},
13+
year = {2019},
14+
howpublished = {\url{https://github.com/wsmoses/Enzyme/}},
15+
note = {commit xxxxxxx}
16+
*/
17+
18+
#include "Utils.h"
19+
20+
#include "llvm/IR/BasicBlock.h"
21+
#include "llvm/IR/DerivedTypes.h"
22+
#include "llvm/IR/Function.h"
23+
#include "llvm/IR/IRBuilder.h"
24+
#include "llvm/IR/Module.h"
25+
#include "llvm/IR/Type.h"
26+
27+
using namespace llvm;
28+
29+
//! Map an LLVM floating-point type to a short textual tag.
//! The tag is appended to "__enzyme_memcpyadd_" below to form a per-type helper name.
//! Any type ID not listed here reaches llvm_unreachable.
static inline std::string tofltstr(Type* T) {
30+
switch (T->getTypeID()) {
31+
case Type::HalfTyID: return "half";
32+
case Type::FloatTyID: return "float";
33+
case Type::DoubleTyID: return "double";
34+
case Type::X86_FP80TyID: return "x87d"; // tag used for the X86_FP80 type
35+
case Type::FP128TyID: return "quad";
36+
case Type::PPC_FP128TyID: return "ppcddouble";
37+
default: llvm_unreachable("Invalid floating type"); // caller must pass one of the floating-point types above
38+
}
39+
}
40+
41+
//! Create function for type that is equivalent to memcpy but adds to destination rather
42+
//! than a direct copy; dst, src, numelems
43+
Function* getOrInsertDifferentialFloatMemcpy(Module& M, PointerType* T) {
44+
Type* elementType = T->getElementType();
45+
assert(elementType->isFloatingPointTy());
46+
std::string name = "__enzyme_memcpyadd_" + tofltstr(elementType);
47+
FunctionType* FT = FunctionType::get(Type::getVoidTy(M.getContext()), { T, T, Type::getInt64Ty(M.getContext()) }, false);
48+
49+
Function* F = cast<Function>(M.getOrInsertFunction(name, FT));
50+
51+
if (!F->empty()) return F;
52+
53+
F->setLinkage(Function::LinkageTypes::InternalLinkage);
54+
F->addFnAttr(Attribute::ArgMemOnly);
55+
F->addFnAttr(Attribute::NoUnwind);
56+
F->addParamAttr(0, Attribute::NoCapture);
57+
F->addParamAttr(1, Attribute::NoCapture);
58+
59+
BasicBlock* entry = BasicBlock::Create(M.getContext(), "entry", F);
60+
BasicBlock* body = BasicBlock::Create(M.getContext(), "for.body", F);
61+
BasicBlock* end = BasicBlock::Create(M.getContext(), "for.end", F);
62+
63+
auto dst = F->arg_begin();
64+
dst->setName("dst");
65+
auto src = dst+1;
66+
src->setName("src");
67+
auto num = src+1;
68+
num->setName("num");
69+
70+
{
71+
IRBuilder<> B(entry);
72+
B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)), end, body);
73+
}
74+
75+
{
76+
IRBuilder<> B(body);
77+
B.setFastMathFlags(getFast());
78+
PHINode* idx = B.CreatePHI(num->getType(), 2, "idx");
79+
idx->addIncoming(ConstantInt::get(num->getType(), 0), entry);
80+
81+
Value* dsti = B.CreateGEP(dst, { idx }, "dst.i");
82+
Value* dstl = B.CreateLoad(dsti, "dst.i.l");
83+
B.CreateStore(Constant::getNullValue(elementType), dsti);
84+
85+
Value* srci = B.CreateGEP(src, { idx }, "src.i");
86+
Value* srcl = B.CreateLoad(srci, "src.i.l");
87+
B.CreateStore(B.CreateFAdd(srcl, dstl), srci);
88+
89+
Value* next = B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next");
90+
idx->addIncoming(next, body);
91+
B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
92+
}
93+
94+
{
95+
IRBuilder<> B(end);
96+
B.CreateRetVoid();
97+
}
98+
return F;
99+
}

enzyme/Enzyme/Utils.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Utils.h
3+
*
4+
* Copyright (C) 2019 William S. Moses (enzyme@wsmoses.com) - All Rights Reserved
5+
*
6+
* For commercial use of this code please contact the author(s) above.
7+
*
8+
* For research use of the code please use the following citation.
9+
*
10+
* \misc{mosesenzyme,
11+
author = {William S. Moses, Tim Kaler},
12+
title = {Enzyme: LLVM Automatic Differentiation},
13+
year = {2019},
14+
howpublished = {\url{https://github.com/wsmoses/Enzyme/}},
15+
note = {commit xxxxxxx}
16+
*/
17+
18+
#ifndef ENZYME_UTILS_H
19+
#define ENZYME_UTILS_H
20+
21+
#include "llvm/IR/Function.h"
22+
#include "llvm/IR/Operator.h"
23+
#include "llvm/IR/Module.h"
24+
#include "llvm/IR/Type.h"
25+
26+
//! Build a FastMathFlags value, call set() on it, and return it by value.
//! NOTE(review): per the LLVM API, the no-argument set() enables every fast-math flag -- confirm for the LLVM version in use.
static inline llvm::FastMathFlags getFast() {
27+
llvm::FastMathFlags f;
28+
f.set();
29+
return f;
30+
}
31+
32+
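//! Get or create the helper that performs the derivative of a memcpy on floating-point memory: for each of `num` elements it adds dst[i] into src[i] and then zeroes dst[i]. T is the pointer type of both buffers; the helper's arguments are (dst, src, num).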
33+
llvm::Function* getOrInsertDifferentialFloatMemcpy(llvm::Module& M, llvm::PointerType* T);
34+
35+
#endif

0 commit comments

Comments
 (0)