Skip to content

Commit a37a72e

Browse files
committed
working memcpy on floats
1 parent 5fa4a24 commit a37a72e

File tree

7 files changed

+276
-42
lines changed

7 files changed

+276
-42
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.swp

bench/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
%-unopt.ll: %.c
66
#../build/bin/clang -fno-unroll-loops $^ -O1 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math | ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
77
#../build/bin/clang -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math | ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
8-
../build/bin/clang -fno-unroll-loops $^ -O0 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math -o $@ #| ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
8+
../build/bin/clang -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o - -ffast-math -o $@ #| ../build/bin/opt -early-cse-memssa - -S -o $@ -simplifycfg -simplifycfg
99

1010
%-unopt.ll: %.cpp
1111
#../build/bin/clang++ -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I ../../adept-2.0.5/include -I../../tapenade/ADFirstAidKit

enzyme/Enzyme/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
if (${LLVM_VERSION_MAJOR} LESS 8)
66
add_llvm_loadable_module( LLVMEnzyme-${LLVM_VERSION_MAJOR}
77
Enzyme.cpp
8+
Utils.cpp
89
SCEV/ScalarEvolutionExpander.cpp
910
DEPENDS
1011
intrinsics_gen
@@ -14,6 +15,7 @@ if (${LLVM_VERSION_MAJOR} LESS 8)
1415
else()
1516
add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR}
1617
Enzyme.cpp
18+
Utils.cpp
1719
SCEV/ScalarEvolutionExpander.cpp
1820
MODULE
1921
DEPENDS

enzyme/Enzyme/Enzyme.cpp

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <llvm/Config/llvm-config.h>
2020

21+
#include "Utils.h"
2122
#include "SCEV/ScalarEvolutionExpander.h"
2223

2324
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
@@ -141,12 +142,6 @@ std::string tostring(DIFFE_TYPE t) {
141142
}
142143
}
143144

144-
// Build a default-constructed FastMathFlags, call set() on it, and return it by value.
// NOTE(review): set() presumably turns on every fast-math flag -- confirm against
// the LLVM version this is built with.
static inline FastMathFlags getFast() {
145-
FastMathFlags f;
146-
f.set();
147-
return f;
148-
}
149-
150145
Instruction *getNextNonDebugInstruction(Instruction* Z) {
151146
for (Instruction *I = Z->getNextNode(); I; I = I->getNextNode())
152147
if (!isa<DbgInfoIntrinsic>(I))
@@ -349,21 +344,22 @@ bool isIntASecretFloat(Value* val) {
349344
assert(0 && "unsure if constant or not");
350345
}
351346

352-
bool isIntPointerASecretFloat(Value* val) {
347+
//! Return the floating-point type this integer pointer secretly refers to, or nullptr if none is found.
348+
Type* isIntPointerASecretFloat(Value* val) {
353349
assert(val->getType()->isPointerTy());
354350
assert(cast<PointerType>(val->getType())->getElementType()->isIntegerTy());
355351

356-
if (isa<UndefValue>(val)) return true;
352+
if (isa<UndefValue>(val)) return nullptr;
357353

358354
if (auto cint = dyn_cast<ConstantInt>(val)) {
359-
if (!cint->isZero()) return false;
355+
if (!cint->isZero()) return nullptr;
360356
assert(0 && "unsure if constant or not because constantint");
361357
//if (cint->isOne()) return cint;
362358
}
363359

364360

365361
if (auto inst = dyn_cast<Instruction>(val)) {
366-
bool floatingUse = false;
362+
Type* floatingUse = nullptr;
367363
bool pointerUse = false;
368364
SmallPtrSet<Value*, 4> seen;
369365

@@ -374,7 +370,11 @@ bool isIntPointerASecretFloat(Value* val) {
374370
do {
375371
Type* let = cast<PointerType>(v->getType())->getElementType();
376372
if (let->isFloatingPointTy()) {
377-
floatingUse = true;
373+
if (floatingUse == nullptr) {
374+
floatingUse = let;
375+
} else {
376+
assert(floatingUse == let);
377+
}
378378
}
379379
if (auto ci = dyn_cast<CastInst>(v)) {
380380
if (auto cal = dyn_cast<CallInst>(ci->getOperand(0))) {
@@ -409,7 +409,11 @@ bool isIntPointerASecretFloat(Value* val) {
409409
llvm::errs() << " for val " << *v << *et << "\n";
410410

411411
if (et->isFloatingPointTy()) {
412-
floatingUse = true;
412+
if (floatingUse == nullptr) {
413+
floatingUse = et;
414+
} else {
415+
assert(floatingUse == et);
416+
}
413417
}
414418
if (et->isPointerTy()) {
415419
pointerUse = true;
@@ -431,8 +435,8 @@ bool isIntPointerASecretFloat(Value* val) {
431435
}
432436
}
433437

434-
if (pointerUse && !floatingUse) return false;
435-
if (!pointerUse && floatingUse) return true;
438+
if (pointerUse && (floatingUse == nullptr)) return nullptr;
439+
if (!pointerUse && (floatingUse != nullptr)) return floatingUse;
436440
llvm::errs() << *inst->getParent()->getParent() << "\n";
437441
llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << "\n";
438442
assert(0 && "ambiguous unsure if constant or not");
@@ -4057,7 +4061,20 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40574061
switch(op->getIntrinsicID()) {
40584062
case Intrinsic::memcpy: {
40594063
if (gutils->isConstantInstruction(inst)) continue;
4060-
if (!isIntPointerASecretFloat(op->getOperand(0)) ) {
4064+
if (Type* secretty = isIntPointerASecretFloat(op->getOperand(0)) ) {
4065+
SmallVector<Value*, 4> args;
4066+
auto secretpt = PointerType::getUnqual(secretty);
4067+
4068+
args.push_back(Builder2.CreatePointerCast(invertPointer(op->getOperand(0)), secretpt));
4069+
args.push_back(Builder2.CreatePointerCast(invertPointer(op->getOperand(1)), secretpt));
4070+
args.push_back(lookup(op->getOperand(2)));
4071+
auto dmemcpy = getOrInsertDifferentialFloatMemcpy(*M, secretpt);
4072+
dmemcpy->dump();
4073+
args[0]->dump();
4074+
args[1]->dump();
4075+
args[2]->dump();
4076+
auto cal = Builder2.CreateCall(dmemcpy, args);
4077+
} else {
40614078
if (topLevel) {
40624079
SmallVector<Value*, 4> args;
40634080
IRBuilder <>BuilderZ(op);
@@ -4072,32 +4089,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40724089
cal->setCallingConv(op->getCallingConv());
40734090
cal->setTailCallKind(op->getTailCallKind());
40744091
}
4075-
} else {
4076-
//no change to forward pass if represents float
4077-
//Zero the destination
4078-
assert(0 && "TODO: memcpy has bug that needs fixing (per int double vs ptr)");
4079-
/*
4080-
{
4081-
TODO BECOME MEMSET
4082-
SmallVector<Value*, 4> args;
4083-
// source and dest are swapped
4084-
args.push_back(invertPointer(op->getOperand(1)));
4085-
args.push_back(invertPointer(op->getOperand(0)));
4086-
args.push_back(lookup(op->getOperand(2)));
4087-
args.push_back(lookup(op->getOperand(3)));
4088-
4089-
Type *tys[] = {args[0]->getType(), args[1]->getType(), args[2]->getType()};
4090-
auto cal = Builder2.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args);
4091-
cal->setAttributes(op->getAttributes());
4092-
cal->setCallingConv(op->getCallingConv());
4093-
cal->setTailCallKind(op->getTailCallKind());
4094-
}
4095-
4096-
4097-
{
4098-
4099-
}
4100-
*/
41014092
}
41024093
break;
41034094
}

enzyme/Enzyme/Utils.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Utils.cpp
3+
*
4+
* Copyright (C) 2019 William S. Moses (enzyme@wsmoses.com) - All Rights Reserved
5+
*
6+
* For commercial use of this code please contact the author(s) above.
7+
*
8+
* For research use of the code please use the following citation.
9+
*
10+
* \misc{mosesenzyme,
11+
author = {William S. Moses, Tim Kaler},
12+
title = {Enzyme: LLVM Automatic Differentiation},
13+
year = {2019},
14+
howpublished = {\url{https://github.com/wsmoses/Enzyme/}},
15+
note = {commit xxxxxxx}
16+
*/
17+
18+
#include "Utils.h"
19+
20+
#include "llvm/IR/BasicBlock.h"
21+
#include "llvm/IR/DerivedTypes.h"
22+
#include "llvm/IR/Function.h"
23+
#include "llvm/IR/IRBuilder.h"
24+
#include "llvm/IR/Module.h"
25+
#include "llvm/IR/Type.h"
26+
27+
using namespace llvm;
28+
29+
//! Map the TypeID of T to a short, fixed name for that floating-point type
//! (for example FloatTyID -> "float", X86_FP80TyID -> "x87d").
//! Only the six TypeIDs listed in the switch are handled; any other TypeID
//! reaches llvm_unreachable("Invalid floating type").
static inline std::string tofltstr(Type* T) {
30+
switch (T->getTypeID()) {
31+
case Type::HalfTyID: return "half";
32+
case Type::FloatTyID: return "float";
33+
case Type::DoubleTyID: return "double";
34+
case Type::X86_FP80TyID: return "x87d";
35+
case Type::FP128TyID: return "quad";
36+
case Type::PPC_FP128TyID: return "ppcddouble";
37+
// Anything that is not one of the floating-point TypeIDs above is treated as a programmer error.
default: llvm_unreachable("Invalid floating type");
38+
}
39+
}
40+
41+
//! Get, or build on first use, the module-local helper
//! `void __enzyme_memcpyadd_<fltname>(T dst, T src, i64 num)`.
//!
//! The emitted loop runs for i in [0, num) and, for each element, does
//!   src[i] = src[i] + dst[i];   dst[i] = 0;
//! i.e. the accumulation goes INTO `src` and `dst` is zeroed, which is the
//! opposite direction of what the parameter names suggest for a plain memcpy.
42+
//! `num` is used as a GEP index bound, so it counts elements of T's pointee
//! type, not bytes. A num of 0 skips the loop entirely.
//!
//! \param M module that will own the helper
//! \param T pointer type whose element type must be floating point (asserted)
//! \return the helper function; an already-populated one is returned unchanged
//!
//! NOTE(review): the helper's name is derived from the element type only, so
//! two pointer types that differ just in address space would map to the same
//! name -- confirm how getOrInsertFunction / cast<Function> behave in that case.
43+
Function* getOrInsertDifferentialFloatMemcpy(Module& M, PointerType* T) {
44+
Type* elementType = T->getElementType();
45+
assert(elementType->isFloatingPointTy());
46+
// One helper per floating-point element type, distinguished by name suffix.
std::string name = "__enzyme_memcpyadd_" + tofltstr(elementType);
47+
// Signature: void(T, T, i64), non-variadic.
FunctionType* FT = FunctionType::get(Type::getVoidTy(M.getContext()), { T, T, Type::getInt64Ty(M.getContext()) }, false);
48+
49+
Function* F = cast<Function>(M.getOrInsertFunction(name, FT));
50+
51+
// A non-empty function already has its body from an earlier call; reuse it.
if (!F->empty()) return F;
52+
53+
// Function-level and parameter attributes for the freshly declared helper.
F->setLinkage(Function::LinkageTypes::InternalLinkage);
54+
F->addFnAttr(Attribute::ArgMemOnly);
55+
F->addFnAttr(Attribute::NoUnwind);
56+
F->addParamAttr(0, Attribute::NoCapture);
57+
F->addParamAttr(1, Attribute::NoCapture);
58+
59+
// Three-block CFG: entry -> (for.body loop) -> for.end.
BasicBlock* entry = BasicBlock::Create(M.getContext(), "entry", F);
60+
BasicBlock* body = BasicBlock::Create(M.getContext(), "for.body", F);
61+
BasicBlock* end = BasicBlock::Create(M.getContext(), "for.end", F);
62+
63+
// Name the three arguments in declaration order: dst, src, num.
auto dst = F->arg_begin();
64+
dst->setName("dst");
65+
auto src = dst+1;
66+
src->setName("src");
67+
auto num = src+1;
68+
num->setName("num");
69+
70+
{
71+
IRBuilder<> B(entry);
72+
// Guard: with num == 0 jump straight to the return block, otherwise enter the loop.
B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)), end, body);
73+
}
74+
75+
{
76+
IRBuilder<> B(body);
77+
// Floating-point instructions created by this builder carry the flags from getFast().
B.setFastMathFlags(getFast());
78+
// Loop counter: 0 when coming from entry, idx.next when coming from the back edge.
PHINode* idx = B.CreatePHI(num->getType(), 2, "idx");
79+
idx->addIncoming(ConstantInt::get(num->getType(), 0), entry);
80+
81+
// Read dst[idx], then overwrite dst[idx] with zero.
Value* dsti = B.CreateGEP(dst, { idx }, "dst.i");
82+
Value* dstl = B.CreateLoad(dsti, "dst.i.l");
83+
B.CreateStore(Constant::getNullValue(elementType), dsti);
84+
85+
// Accumulate the value read from dst into src[idx]. The load of src[idx]
// is emitted after the zeroing store to dst[idx].
Value* srci = B.CreateGEP(src, { idx }, "src.i");
86+
Value* srcl = B.CreateLoad(srci, "src.i.l");
87+
B.CreateStore(B.CreateFAdd(srcl, dstl), srci);
88+
89+
// Advance the counter (marked no-unsigned-wrap) and loop until it equals num.
Value* next = B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next");
90+
idx->addIncoming(next, body);
91+
B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
92+
}
93+
94+
{
95+
IRBuilder<> B(end);
96+
B.CreateRetVoid();
97+
}
98+
return F;
99+
}

enzyme/Enzyme/Utils.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Utils.h
3+
*
4+
* Copyright (C) 2019 William S. Moses (enzyme@wsmoses.com) - All Rights Reserved
5+
*
6+
* For commercial use of this code please contact the author(s) above.
7+
*
8+
* For research use of the code please use the following citation.
9+
*
10+
* \misc{mosesenzyme,
11+
author = {William S. Moses, Tim Kaler},
12+
title = {Enzyme: LLVM Automatic Differentiation},
13+
year = {2019},
14+
howpublished = {\url{https://github.com/wsmoses/Enzyme/}},
15+
note = {commit xxxxxxx}
16+
*/
17+
18+
#ifndef ENZYME_UTILS_H
19+
#define ENZYME_UTILS_H
20+
21+
#include "llvm/IR/Function.h"
22+
#include "llvm/IR/Operator.h"
23+
#include "llvm/IR/Module.h"
24+
#include "llvm/IR/Type.h"
25+
26+
//! Build a default-constructed llvm::FastMathFlags, call set() on it, and return it by value.
//! NOTE(review): set() presumably enables every fast-math flag -- confirm against
//! the LLVM version this is built with.
static inline llvm::FastMathFlags getFast() {
27+
llvm::FastMathFlags f;
28+
f.set();
29+
return f;
30+
}
31+
32+
//! Get, or create on first use, the helper in module M that performs the derivative of a
//! memcpy on floating-point memory. T is the pointer type of the buffers; its element type
//! must be floating point. The returned function takes (T dst, T src, i64 num).
33+
llvm::Function* getOrInsertDifferentialFloatMemcpy(llvm::Module& M, llvm::PointerType* T);
34+
35+
#endif

0 commit comments

Comments
 (0)