Skip to content

Commit 25048e5

Browse files
authored
Refactor blas (rust-lang#389)
1 parent a16fe1b commit 25048e5

File tree

1 file changed

+188
-174
lines changed

1 file changed

+188
-174
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 188 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -4107,6 +4107,191 @@ class AdjointGenerator
41074107
}
41084108
}
41094109

4110+
bool handleBLAS(llvm::CallInst &call, Function *called, StringRef funcName,
4111+
const std::map<Argument *, bool> &uncacheable_args) {
4112+
CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
4113+
IRBuilder<> BuilderZ(newCall);
4114+
BuilderZ.setFastMathFlags(getFast());
4115+
4116+
if ((funcName == "cblas_ddot" || funcName == "cblas_sdot") &&
4117+
called->isDeclaration()) {
4118+
Type *innerType;
4119+
std::string dfuncName;
4120+
if (funcName == "cblas_ddot") {
4121+
innerType = Type::getDoubleTy(call.getContext());
4122+
dfuncName = "cblas_daxpy";
4123+
} else if (funcName == "cblas_sdot") {
4124+
innerType = Type::getFloatTy(call.getContext());
4125+
dfuncName = "cblas_saxpy";
4126+
} else {
4127+
assert(false && "Unreachable");
4128+
}
4129+
Type *castvals[2] = {call.getArgOperand(1)->getType(),
4130+
call.getArgOperand(3)->getType()};
4131+
auto *cachetype =
4132+
StructType::get(call.getContext(), ArrayRef<Type *>(castvals));
4133+
Value *undefinit = UndefValue::get(cachetype);
4134+
Value *cacheval;
4135+
auto in_arg = call.getCalledFunction()->arg_begin();
4136+
in_arg++;
4137+
Argument *xfuncarg = in_arg;
4138+
in_arg++;
4139+
in_arg++;
4140+
Argument *yfuncarg = in_arg;
4141+
bool xcache = !gutils->isConstantValue(call.getArgOperand(3)) &&
4142+
uncacheable_args.find(xfuncarg)->second;
4143+
bool ycache = !gutils->isConstantValue(call.getArgOperand(1)) &&
4144+
uncacheable_args.find(yfuncarg)->second;
4145+
if ((Mode == DerivativeMode::ReverseModeCombined ||
4146+
Mode == DerivativeMode::ReverseModePrimal) &&
4147+
(xcache || ycache)) {
4148+
4149+
Value *arg1, *arg2;
4150+
auto size = ConstantExpr::getSizeOf(innerType);
4151+
if (xcache) {
4152+
auto dmemcpy =
4153+
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
4154+
PointerType::getUnqual(innerType), 0, 0);
4155+
auto malins = CallInst::CreateMalloc(
4156+
gutils->getNewFromOriginal(&call), size->getType(), innerType,
4157+
size, call.getArgOperand(0), nullptr, "");
4158+
arg1 =
4159+
BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType());
4160+
SmallVector<Value *, 4> args;
4161+
args.push_back(arg1);
4162+
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(1)));
4163+
args.push_back(call.getArgOperand(0));
4164+
args.push_back(call.getArgOperand(2));
4165+
BuilderZ.CreateCall(dmemcpy, args);
4166+
}
4167+
if (ycache) {
4168+
auto dmemcpy =
4169+
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
4170+
PointerType::getUnqual(innerType), 0, 0);
4171+
auto malins = CallInst::CreateMalloc(
4172+
gutils->getNewFromOriginal(&call), size->getType(), innerType,
4173+
size, call.getArgOperand(0), nullptr, "");
4174+
arg2 =
4175+
BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType());
4176+
SmallVector<Value *, 4> args;
4177+
args.push_back(arg2);
4178+
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(3)));
4179+
args.push_back(call.getArgOperand(0));
4180+
args.push_back(call.getArgOperand(4));
4181+
BuilderZ.CreateCall(dmemcpy, args);
4182+
}
4183+
if (xcache && ycache) {
4184+
auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0);
4185+
cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1);
4186+
} else if (xcache)
4187+
cacheval = arg1;
4188+
else {
4189+
assert(ycache);
4190+
cacheval = arg2;
4191+
}
4192+
gutils->cacheForReverse(BuilderZ, cacheval,
4193+
getIndex(&call, CacheType::Tape));
4194+
}
4195+
if (Mode == DerivativeMode::ReverseModeCombined ||
4196+
Mode == DerivativeMode::ReverseModeGradient) {
4197+
IRBuilder<> Builder2(call.getParent());
4198+
getReverseBuilder(Builder2);
4199+
auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction(
4200+
dfuncName, Builder2.getVoidTy(), Builder2.getInt32Ty(), innerType,
4201+
call.getArgOperand(1)->getType(), Builder2.getInt32Ty(),
4202+
call.getArgOperand(3)->getType(), Builder2.getInt32Ty());
4203+
Value *structarg1;
4204+
Value *structarg2;
4205+
if (xcache || ycache) {
4206+
if (Mode == DerivativeMode::ReverseModeGradient &&
4207+
(!gutils->isConstantValue(call.getArgOperand(1)) ||
4208+
!gutils->isConstantValue(call.getArgOperand(3)))) {
4209+
cacheval = BuilderZ.CreatePHI(cachetype, 0);
4210+
}
4211+
cacheval =
4212+
lookup(gutils->cacheForReverse(BuilderZ, cacheval,
4213+
getIndex(&call, CacheType::Tape)),
4214+
Builder2);
4215+
if (xcache && ycache) {
4216+
structarg1 = BuilderZ.CreateExtractValue(cacheval, 0);
4217+
structarg2 = BuilderZ.CreateExtractValue(cacheval, 1);
4218+
} else if (xcache)
4219+
structarg1 = cacheval;
4220+
else if (ycache)
4221+
structarg2 = cacheval;
4222+
}
4223+
if (!xcache)
4224+
structarg1 = lookup(gutils->getNewFromOriginal(call.getArgOperand(1)),
4225+
Builder2);
4226+
if (!ycache)
4227+
structarg2 = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)),
4228+
Builder2);
4229+
CallInst *firstdcall, *seconddcall;
4230+
if (!gutils->isConstantValue(call.getArgOperand(3))) {
4231+
Value *estride;
4232+
if (xcache)
4233+
estride = Builder2.getInt32(1);
4234+
else
4235+
estride = lookup(gutils->getNewFromOriginal(call.getArgOperand(2)),
4236+
Builder2);
4237+
SmallVector<Value *, 6> args1 = {
4238+
lookup(gutils->getNewFromOriginal(call.getArgOperand(0)),
4239+
Builder2),
4240+
diffe(&call, Builder2),
4241+
structarg1,
4242+
estride,
4243+
lookup(gutils->invertPointerM(call.getArgOperand(3), Builder2),
4244+
Builder2),
4245+
lookup(gutils->getNewFromOriginal(call.getArgOperand(4)),
4246+
Builder2)};
4247+
firstdcall = Builder2.CreateCall(derivcall, args1);
4248+
}
4249+
if (!gutils->isConstantValue(call.getArgOperand(1))) {
4250+
Value *estride;
4251+
if (ycache)
4252+
estride = Builder2.getInt32(1);
4253+
else
4254+
estride = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)),
4255+
Builder2);
4256+
SmallVector<Value *, 6> args2 = {
4257+
lookup(gutils->getNewFromOriginal(call.getArgOperand(0)),
4258+
Builder2),
4259+
diffe(&call, Builder2),
4260+
structarg2,
4261+
estride,
4262+
lookup(gutils->invertPointerM(call.getArgOperand(1), Builder2),
4263+
Builder2),
4264+
lookup(gutils->getNewFromOriginal(call.getArgOperand(2)),
4265+
Builder2)};
4266+
seconddcall = Builder2.CreateCall(derivcall, args2);
4267+
}
4268+
setDiffe(&call, Constant::getNullValue(call.getType()), Builder2);
4269+
if (shouldFree()) {
4270+
if (xcache)
4271+
CallInst::CreateFree(structarg1, firstdcall->getNextNode());
4272+
if (ycache)
4273+
CallInst::CreateFree(structarg2, seconddcall->getNextNode());
4274+
}
4275+
}
4276+
4277+
if (gutils->knownRecomputeHeuristic.find(&call) !=
4278+
gutils->knownRecomputeHeuristic.end()) {
4279+
if (!gutils->knownRecomputeHeuristic[&call]) {
4280+
gutils->cacheForReverse(BuilderZ, newCall,
4281+
getIndex(&call, CacheType::Self));
4282+
}
4283+
}
4284+
4285+
if (Mode == DerivativeMode::ReverseModeGradient) {
4286+
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4287+
} else {
4288+
eraseIfUnused(call);
4289+
}
4290+
return true;
4291+
}
4292+
return false;
4293+
}
4294+
41104295
void handleMPI(llvm::CallInst &call, Function *called, StringRef funcName) {
41114296
assert(Mode != DerivativeMode::ForwardMode);
41124297
assert(called);
@@ -6018,180 +6203,9 @@ class AdjointGenerator
60186203
return;
60196204
}
60206205

6021-
if ((funcName == "cblas_ddot" || funcName == "cblas_sdot") &&
6022-
called->isDeclaration()) {
6023-
Type *innerType;
6024-
std::string dfuncName;
6025-
if (funcName == "cblas_ddot") {
6026-
innerType = Type::getDoubleTy(call.getContext());
6027-
dfuncName = "cblas_daxpy";
6028-
} else if (funcName == "cblas_sdot") {
6029-
innerType = Type::getFloatTy(call.getContext());
6030-
dfuncName = "cblas_saxpy";
6031-
} else {
6032-
assert(false && "Unreachable");
6033-
}
6034-
Type *castvals[2] = {call.getArgOperand(1)->getType(),
6035-
call.getArgOperand(3)->getType()};
6036-
auto *cachetype =
6037-
StructType::get(call.getContext(), ArrayRef<Type *>(castvals));
6038-
Value *undefinit = UndefValue::get(cachetype);
6039-
Value *cacheval;
6040-
auto in_arg = call.getCalledFunction()->arg_begin();
6041-
in_arg++;
6042-
Argument *xfuncarg = in_arg;
6043-
in_arg++;
6044-
in_arg++;
6045-
Argument *yfuncarg = in_arg;
6046-
bool xcache = !gutils->isConstantValue(call.getArgOperand(3)) &&
6047-
uncacheable_args.find(xfuncarg)->second;
6048-
bool ycache = !gutils->isConstantValue(call.getArgOperand(1)) &&
6049-
uncacheable_args.find(yfuncarg)->second;
6050-
if ((Mode == DerivativeMode::ReverseModeCombined ||
6051-
Mode == DerivativeMode::ReverseModePrimal) &&
6052-
(xcache || ycache)) {
6053-
Value *arg1, *arg2;
6054-
auto size = ConstantExpr::getSizeOf(innerType);
6055-
if (xcache) {
6056-
auto dmemcpy =
6057-
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
6058-
PointerType::getUnqual(innerType), 0, 0);
6059-
auto malins = CallInst::CreateMalloc(
6060-
gutils->getNewFromOriginal(&call), size->getType(), innerType,
6061-
size, call.getArgOperand(0), nullptr, "");
6062-
arg1 =
6063-
BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType());
6064-
SmallVector<Value *, 4> args;
6065-
args.push_back(arg1);
6066-
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(1)));
6067-
args.push_back(call.getArgOperand(0));
6068-
args.push_back(call.getArgOperand(2));
6069-
BuilderZ.CreateCall(dmemcpy, args);
6070-
}
6071-
if (ycache) {
6072-
auto dmemcpy =
6073-
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
6074-
PointerType::getUnqual(innerType), 0, 0);
6075-
auto malins = CallInst::CreateMalloc(
6076-
gutils->getNewFromOriginal(&call), size->getType(), innerType,
6077-
size, call.getArgOperand(0), nullptr, "");
6078-
arg2 =
6079-
BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType());
6080-
SmallVector<Value *, 4> args;
6081-
args.push_back(arg2);
6082-
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(3)));
6083-
args.push_back(call.getArgOperand(0));
6084-
args.push_back(call.getArgOperand(4));
6085-
BuilderZ.CreateCall(dmemcpy, args);
6086-
}
6087-
if (xcache && ycache) {
6088-
auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0);
6089-
cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1);
6090-
} else if (xcache)
6091-
cacheval = arg1;
6092-
else {
6093-
assert(ycache);
6094-
cacheval = arg2;
6095-
}
6096-
gutils->cacheForReverse(BuilderZ, cacheval,
6097-
getIndex(&call, CacheType::Tape));
6098-
}
6099-
if (Mode == DerivativeMode::ReverseModeCombined ||
6100-
Mode == DerivativeMode::ReverseModeGradient) {
6101-
IRBuilder<> Builder2(call.getParent());
6102-
getReverseBuilder(Builder2);
6103-
auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction(
6104-
dfuncName, Builder2.getVoidTy(), Builder2.getInt32Ty(), innerType,
6105-
call.getArgOperand(1)->getType(), Builder2.getInt32Ty(),
6106-
call.getArgOperand(3)->getType(), Builder2.getInt32Ty());
6107-
Value *structarg1;
6108-
Value *structarg2;
6109-
if (xcache || ycache) {
6110-
if (Mode == DerivativeMode::ReverseModeGradient &&
6111-
(!gutils->isConstantValue(call.getArgOperand(1)) ||
6112-
!gutils->isConstantValue(call.getArgOperand(3)))) {
6113-
cacheval = BuilderZ.CreatePHI(cachetype, 0);
6114-
}
6115-
cacheval =
6116-
lookup(gutils->cacheForReverse(BuilderZ, cacheval,
6117-
getIndex(&call, CacheType::Tape)),
6118-
Builder2);
6119-
if (xcache && ycache) {
6120-
structarg1 = BuilderZ.CreateExtractValue(cacheval, 0);
6121-
structarg2 = BuilderZ.CreateExtractValue(cacheval, 1);
6122-
} else if (xcache)
6123-
structarg1 = cacheval;
6124-
else if (ycache)
6125-
structarg2 = cacheval;
6126-
}
6127-
if (!xcache)
6128-
structarg1 = lookup(
6129-
gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2);
6130-
if (!ycache)
6131-
structarg2 = lookup(
6132-
gutils->getNewFromOriginal(orig->getArgOperand(3)), Builder2);
6133-
CallInst *firstdcall, *seconddcall;
6134-
if (!gutils->isConstantValue(call.getArgOperand(3))) {
6135-
Value *estride;
6136-
if (xcache)
6137-
estride = Builder2.getInt32(1);
6138-
else
6139-
estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)),
6140-
Builder2);
6141-
SmallVector<Value *, 6> args1 = {
6142-
lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
6143-
Builder2),
6144-
diffe(orig, Builder2),
6145-
structarg1,
6146-
estride,
6147-
lookup(gutils->invertPointerM(orig->getArgOperand(3), Builder2),
6148-
Builder2),
6149-
lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)),
6150-
Builder2)};
6151-
firstdcall = Builder2.CreateCall(derivcall, args1);
6152-
}
6153-
if (!gutils->isConstantValue(call.getArgOperand(1))) {
6154-
Value *estride;
6155-
if (ycache)
6156-
estride = Builder2.getInt32(1);
6157-
else
6158-
estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)),
6159-
Builder2);
6160-
SmallVector<Value *, 6> args2 = {
6161-
lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
6162-
Builder2),
6163-
diffe(orig, Builder2),
6164-
structarg2,
6165-
estride,
6166-
lookup(gutils->invertPointerM(orig->getArgOperand(1), Builder2),
6167-
Builder2),
6168-
lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)),
6169-
Builder2)};
6170-
seconddcall = Builder2.CreateCall(derivcall, args2);
6171-
}
6172-
setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2);
6173-
if (shouldFree()) {
6174-
if (xcache)
6175-
CallInst::CreateFree(structarg1, firstdcall->getNextNode());
6176-
if (ycache)
6177-
CallInst::CreateFree(structarg2, seconddcall->getNextNode());
6178-
}
6179-
}
6180-
6181-
if (gutils->knownRecomputeHeuristic.find(orig) !=
6182-
gutils->knownRecomputeHeuristic.end()) {
6183-
if (!gutils->knownRecomputeHeuristic[orig]) {
6184-
gutils->cacheForReverse(BuilderZ, newCall,
6185-
getIndex(orig, CacheType::Self));
6186-
}
6187-
}
6188-
6189-
if (Mode == DerivativeMode::ReverseModeGradient) {
6190-
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
6191-
} else {
6192-
eraseIfUnused(*orig);
6193-
}
6194-
return;
6206+
if ((funcName == "cblas_ddot" || funcName == "cblas_sdot")) {
6207+
if (handleBLAS(call, called, funcName, uncacheable_args))
6208+
return;
61956209
}
61966210

61976211
if (funcName == "printf" || funcName == "puts" ||

0 commit comments

Comments
 (0)