@@ -4107,6 +4107,191 @@ class AdjointGenerator
4107
4107
}
4108
4108
}
4109
4109
4110
+ bool handleBLAS (llvm::CallInst &call, Function *called, StringRef funcName,
4111
+ const std::map<Argument *, bool > &uncacheable_args) {
4112
+ CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal (&call));
4113
+ IRBuilder<> BuilderZ (newCall);
4114
+ BuilderZ.setFastMathFlags (getFast ());
4115
+
4116
+ if ((funcName == " cblas_ddot" || funcName == " cblas_sdot" ) &&
4117
+ called->isDeclaration ()) {
4118
+ Type *innerType;
4119
+ std::string dfuncName;
4120
+ if (funcName == " cblas_ddot" ) {
4121
+ innerType = Type::getDoubleTy (call.getContext ());
4122
+ dfuncName = " cblas_daxpy" ;
4123
+ } else if (funcName == " cblas_sdot" ) {
4124
+ innerType = Type::getFloatTy (call.getContext ());
4125
+ dfuncName = " cblas_saxpy" ;
4126
+ } else {
4127
+ assert (false && " Unreachable" );
4128
+ }
4129
+ Type *castvals[2 ] = {call.getArgOperand (1 )->getType (),
4130
+ call.getArgOperand (3 )->getType ()};
4131
+ auto *cachetype =
4132
+ StructType::get (call.getContext (), ArrayRef<Type *>(castvals));
4133
+ Value *undefinit = UndefValue::get (cachetype);
4134
+ Value *cacheval;
4135
+ auto in_arg = call.getCalledFunction ()->arg_begin ();
4136
+ in_arg++;
4137
+ Argument *xfuncarg = in_arg;
4138
+ in_arg++;
4139
+ in_arg++;
4140
+ Argument *yfuncarg = in_arg;
4141
+ bool xcache = !gutils->isConstantValue (call.getArgOperand (3 )) &&
4142
+ uncacheable_args.find (xfuncarg)->second ;
4143
+ bool ycache = !gutils->isConstantValue (call.getArgOperand (1 )) &&
4144
+ uncacheable_args.find (yfuncarg)->second ;
4145
+ if ((Mode == DerivativeMode::ReverseModeCombined ||
4146
+ Mode == DerivativeMode::ReverseModePrimal) &&
4147
+ (xcache || ycache)) {
4148
+
4149
+ Value *arg1, *arg2;
4150
+ auto size = ConstantExpr::getSizeOf (innerType);
4151
+ if (xcache) {
4152
+ auto dmemcpy =
4153
+ getOrInsertMemcpyStrided (*gutils->oldFunc ->getParent (),
4154
+ PointerType::getUnqual (innerType), 0 , 0 );
4155
+ auto malins = CallInst::CreateMalloc (
4156
+ gutils->getNewFromOriginal (&call), size->getType (), innerType,
4157
+ size, call.getArgOperand (0 ), nullptr , " " );
4158
+ arg1 =
4159
+ BuilderZ.CreateBitCast (malins, call.getArgOperand (1 )->getType ());
4160
+ SmallVector<Value *, 4 > args;
4161
+ args.push_back (arg1);
4162
+ args.push_back (gutils->getNewFromOriginal (call.getArgOperand (1 )));
4163
+ args.push_back (call.getArgOperand (0 ));
4164
+ args.push_back (call.getArgOperand (2 ));
4165
+ BuilderZ.CreateCall (dmemcpy, args);
4166
+ }
4167
+ if (ycache) {
4168
+ auto dmemcpy =
4169
+ getOrInsertMemcpyStrided (*gutils->oldFunc ->getParent (),
4170
+ PointerType::getUnqual (innerType), 0 , 0 );
4171
+ auto malins = CallInst::CreateMalloc (
4172
+ gutils->getNewFromOriginal (&call), size->getType (), innerType,
4173
+ size, call.getArgOperand (0 ), nullptr , " " );
4174
+ arg2 =
4175
+ BuilderZ.CreateBitCast (malins, call.getArgOperand (3 )->getType ());
4176
+ SmallVector<Value *, 4 > args;
4177
+ args.push_back (arg2);
4178
+ args.push_back (gutils->getNewFromOriginal (call.getArgOperand (3 )));
4179
+ args.push_back (call.getArgOperand (0 ));
4180
+ args.push_back (call.getArgOperand (4 ));
4181
+ BuilderZ.CreateCall (dmemcpy, args);
4182
+ }
4183
+ if (xcache && ycache) {
4184
+ auto valins1 = BuilderZ.CreateInsertValue (undefinit, arg1, 0 );
4185
+ cacheval = BuilderZ.CreateInsertValue (valins1, arg2, 1 );
4186
+ } else if (xcache)
4187
+ cacheval = arg1;
4188
+ else {
4189
+ assert (ycache);
4190
+ cacheval = arg2;
4191
+ }
4192
+ gutils->cacheForReverse (BuilderZ, cacheval,
4193
+ getIndex (&call, CacheType::Tape));
4194
+ }
4195
+ if (Mode == DerivativeMode::ReverseModeCombined ||
4196
+ Mode == DerivativeMode::ReverseModeGradient) {
4197
+ IRBuilder<> Builder2 (call.getParent ());
4198
+ getReverseBuilder (Builder2);
4199
+ auto derivcall = gutils->oldFunc ->getParent ()->getOrInsertFunction (
4200
+ dfuncName, Builder2.getVoidTy (), Builder2.getInt32Ty (), innerType,
4201
+ call.getArgOperand (1 )->getType (), Builder2.getInt32Ty (),
4202
+ call.getArgOperand (3 )->getType (), Builder2.getInt32Ty ());
4203
+ Value *structarg1;
4204
+ Value *structarg2;
4205
+ if (xcache || ycache) {
4206
+ if (Mode == DerivativeMode::ReverseModeGradient &&
4207
+ (!gutils->isConstantValue (call.getArgOperand (1 )) ||
4208
+ !gutils->isConstantValue (call.getArgOperand (3 )))) {
4209
+ cacheval = BuilderZ.CreatePHI (cachetype, 0 );
4210
+ }
4211
+ cacheval =
4212
+ lookup (gutils->cacheForReverse (BuilderZ, cacheval,
4213
+ getIndex (&call, CacheType::Tape)),
4214
+ Builder2);
4215
+ if (xcache && ycache) {
4216
+ structarg1 = BuilderZ.CreateExtractValue (cacheval, 0 );
4217
+ structarg2 = BuilderZ.CreateExtractValue (cacheval, 1 );
4218
+ } else if (xcache)
4219
+ structarg1 = cacheval;
4220
+ else if (ycache)
4221
+ structarg2 = cacheval;
4222
+ }
4223
+ if (!xcache)
4224
+ structarg1 = lookup (gutils->getNewFromOriginal (call.getArgOperand (1 )),
4225
+ Builder2);
4226
+ if (!ycache)
4227
+ structarg2 = lookup (gutils->getNewFromOriginal (call.getArgOperand (3 )),
4228
+ Builder2);
4229
+ CallInst *firstdcall, *seconddcall;
4230
+ if (!gutils->isConstantValue (call.getArgOperand (3 ))) {
4231
+ Value *estride;
4232
+ if (xcache)
4233
+ estride = Builder2.getInt32 (1 );
4234
+ else
4235
+ estride = lookup (gutils->getNewFromOriginal (call.getArgOperand (2 )),
4236
+ Builder2);
4237
+ SmallVector<Value *, 6 > args1 = {
4238
+ lookup (gutils->getNewFromOriginal (call.getArgOperand (0 )),
4239
+ Builder2),
4240
+ diffe (&call, Builder2),
4241
+ structarg1,
4242
+ estride,
4243
+ lookup (gutils->invertPointerM (call.getArgOperand (3 ), Builder2),
4244
+ Builder2),
4245
+ lookup (gutils->getNewFromOriginal (call.getArgOperand (4 )),
4246
+ Builder2)};
4247
+ firstdcall = Builder2.CreateCall (derivcall, args1);
4248
+ }
4249
+ if (!gutils->isConstantValue (call.getArgOperand (1 ))) {
4250
+ Value *estride;
4251
+ if (ycache)
4252
+ estride = Builder2.getInt32 (1 );
4253
+ else
4254
+ estride = lookup (gutils->getNewFromOriginal (call.getArgOperand (4 )),
4255
+ Builder2);
4256
+ SmallVector<Value *, 6 > args2 = {
4257
+ lookup (gutils->getNewFromOriginal (call.getArgOperand (0 )),
4258
+ Builder2),
4259
+ diffe (&call, Builder2),
4260
+ structarg2,
4261
+ estride,
4262
+ lookup (gutils->invertPointerM (call.getArgOperand (1 ), Builder2),
4263
+ Builder2),
4264
+ lookup (gutils->getNewFromOriginal (call.getArgOperand (2 )),
4265
+ Builder2)};
4266
+ seconddcall = Builder2.CreateCall (derivcall, args2);
4267
+ }
4268
+ setDiffe (&call, Constant::getNullValue (call.getType ()), Builder2);
4269
+ if (shouldFree ()) {
4270
+ if (xcache)
4271
+ CallInst::CreateFree (structarg1, firstdcall->getNextNode ());
4272
+ if (ycache)
4273
+ CallInst::CreateFree (structarg2, seconddcall->getNextNode ());
4274
+ }
4275
+ }
4276
+
4277
+ if (gutils->knownRecomputeHeuristic .find (&call) !=
4278
+ gutils->knownRecomputeHeuristic .end ()) {
4279
+ if (!gutils->knownRecomputeHeuristic [&call]) {
4280
+ gutils->cacheForReverse (BuilderZ, newCall,
4281
+ getIndex (&call, CacheType::Self));
4282
+ }
4283
+ }
4284
+
4285
+ if (Mode == DerivativeMode::ReverseModeGradient) {
4286
+ eraseIfUnused (call, /* erase*/ true , /* check*/ false );
4287
+ } else {
4288
+ eraseIfUnused (call);
4289
+ }
4290
+ return true ;
4291
+ }
4292
+ return false ;
4293
+ }
4294
+
4110
4295
void handleMPI (llvm::CallInst &call, Function *called, StringRef funcName) {
4111
4296
assert (Mode != DerivativeMode::ForwardMode);
4112
4297
assert (called);
@@ -6018,180 +6203,9 @@ class AdjointGenerator
6018
6203
return ;
6019
6204
}
6020
6205
6021
- if ((funcName == " cblas_ddot" || funcName == " cblas_sdot" ) &&
6022
- called->isDeclaration ()) {
6023
- Type *innerType;
6024
- std::string dfuncName;
6025
- if (funcName == " cblas_ddot" ) {
6026
- innerType = Type::getDoubleTy (call.getContext ());
6027
- dfuncName = " cblas_daxpy" ;
6028
- } else if (funcName == " cblas_sdot" ) {
6029
- innerType = Type::getFloatTy (call.getContext ());
6030
- dfuncName = " cblas_saxpy" ;
6031
- } else {
6032
- assert (false && " Unreachable" );
6033
- }
6034
- Type *castvals[2 ] = {call.getArgOperand (1 )->getType (),
6035
- call.getArgOperand (3 )->getType ()};
6036
- auto *cachetype =
6037
- StructType::get (call.getContext (), ArrayRef<Type *>(castvals));
6038
- Value *undefinit = UndefValue::get (cachetype);
6039
- Value *cacheval;
6040
- auto in_arg = call.getCalledFunction ()->arg_begin ();
6041
- in_arg++;
6042
- Argument *xfuncarg = in_arg;
6043
- in_arg++;
6044
- in_arg++;
6045
- Argument *yfuncarg = in_arg;
6046
- bool xcache = !gutils->isConstantValue (call.getArgOperand (3 )) &&
6047
- uncacheable_args.find (xfuncarg)->second ;
6048
- bool ycache = !gutils->isConstantValue (call.getArgOperand (1 )) &&
6049
- uncacheable_args.find (yfuncarg)->second ;
6050
- if ((Mode == DerivativeMode::ReverseModeCombined ||
6051
- Mode == DerivativeMode::ReverseModePrimal) &&
6052
- (xcache || ycache)) {
6053
- Value *arg1, *arg2;
6054
- auto size = ConstantExpr::getSizeOf (innerType);
6055
- if (xcache) {
6056
- auto dmemcpy =
6057
- getOrInsertMemcpyStrided (*gutils->oldFunc ->getParent (),
6058
- PointerType::getUnqual (innerType), 0 , 0 );
6059
- auto malins = CallInst::CreateMalloc (
6060
- gutils->getNewFromOriginal (&call), size->getType (), innerType,
6061
- size, call.getArgOperand (0 ), nullptr , " " );
6062
- arg1 =
6063
- BuilderZ.CreateBitCast (malins, call.getArgOperand (1 )->getType ());
6064
- SmallVector<Value *, 4 > args;
6065
- args.push_back (arg1);
6066
- args.push_back (gutils->getNewFromOriginal (call.getArgOperand (1 )));
6067
- args.push_back (call.getArgOperand (0 ));
6068
- args.push_back (call.getArgOperand (2 ));
6069
- BuilderZ.CreateCall (dmemcpy, args);
6070
- }
6071
- if (ycache) {
6072
- auto dmemcpy =
6073
- getOrInsertMemcpyStrided (*gutils->oldFunc ->getParent (),
6074
- PointerType::getUnqual (innerType), 0 , 0 );
6075
- auto malins = CallInst::CreateMalloc (
6076
- gutils->getNewFromOriginal (&call), size->getType (), innerType,
6077
- size, call.getArgOperand (0 ), nullptr , " " );
6078
- arg2 =
6079
- BuilderZ.CreateBitCast (malins, call.getArgOperand (3 )->getType ());
6080
- SmallVector<Value *, 4 > args;
6081
- args.push_back (arg2);
6082
- args.push_back (gutils->getNewFromOriginal (call.getArgOperand (3 )));
6083
- args.push_back (call.getArgOperand (0 ));
6084
- args.push_back (call.getArgOperand (4 ));
6085
- BuilderZ.CreateCall (dmemcpy, args);
6086
- }
6087
- if (xcache && ycache) {
6088
- auto valins1 = BuilderZ.CreateInsertValue (undefinit, arg1, 0 );
6089
- cacheval = BuilderZ.CreateInsertValue (valins1, arg2, 1 );
6090
- } else if (xcache)
6091
- cacheval = arg1;
6092
- else {
6093
- assert (ycache);
6094
- cacheval = arg2;
6095
- }
6096
- gutils->cacheForReverse (BuilderZ, cacheval,
6097
- getIndex (&call, CacheType::Tape));
6098
- }
6099
- if (Mode == DerivativeMode::ReverseModeCombined ||
6100
- Mode == DerivativeMode::ReverseModeGradient) {
6101
- IRBuilder<> Builder2 (call.getParent ());
6102
- getReverseBuilder (Builder2);
6103
- auto derivcall = gutils->oldFunc ->getParent ()->getOrInsertFunction (
6104
- dfuncName, Builder2.getVoidTy (), Builder2.getInt32Ty (), innerType,
6105
- call.getArgOperand (1 )->getType (), Builder2.getInt32Ty (),
6106
- call.getArgOperand (3 )->getType (), Builder2.getInt32Ty ());
6107
- Value *structarg1;
6108
- Value *structarg2;
6109
- if (xcache || ycache) {
6110
- if (Mode == DerivativeMode::ReverseModeGradient &&
6111
- (!gutils->isConstantValue (call.getArgOperand (1 )) ||
6112
- !gutils->isConstantValue (call.getArgOperand (3 )))) {
6113
- cacheval = BuilderZ.CreatePHI (cachetype, 0 );
6114
- }
6115
- cacheval =
6116
- lookup (gutils->cacheForReverse (BuilderZ, cacheval,
6117
- getIndex (&call, CacheType::Tape)),
6118
- Builder2);
6119
- if (xcache && ycache) {
6120
- structarg1 = BuilderZ.CreateExtractValue (cacheval, 0 );
6121
- structarg2 = BuilderZ.CreateExtractValue (cacheval, 1 );
6122
- } else if (xcache)
6123
- structarg1 = cacheval;
6124
- else if (ycache)
6125
- structarg2 = cacheval;
6126
- }
6127
- if (!xcache)
6128
- structarg1 = lookup (
6129
- gutils->getNewFromOriginal (orig->getArgOperand (1 )), Builder2);
6130
- if (!ycache)
6131
- structarg2 = lookup (
6132
- gutils->getNewFromOriginal (orig->getArgOperand (3 )), Builder2);
6133
- CallInst *firstdcall, *seconddcall;
6134
- if (!gutils->isConstantValue (call.getArgOperand (3 ))) {
6135
- Value *estride;
6136
- if (xcache)
6137
- estride = Builder2.getInt32 (1 );
6138
- else
6139
- estride = lookup (gutils->getNewFromOriginal (orig->getArgOperand (2 )),
6140
- Builder2);
6141
- SmallVector<Value *, 6 > args1 = {
6142
- lookup (gutils->getNewFromOriginal (orig->getArgOperand (0 )),
6143
- Builder2),
6144
- diffe (orig, Builder2),
6145
- structarg1,
6146
- estride,
6147
- lookup (gutils->invertPointerM (orig->getArgOperand (3 ), Builder2),
6148
- Builder2),
6149
- lookup (gutils->getNewFromOriginal (orig->getArgOperand (4 )),
6150
- Builder2)};
6151
- firstdcall = Builder2.CreateCall (derivcall, args1);
6152
- }
6153
- if (!gutils->isConstantValue (call.getArgOperand (1 ))) {
6154
- Value *estride;
6155
- if (ycache)
6156
- estride = Builder2.getInt32 (1 );
6157
- else
6158
- estride = lookup (gutils->getNewFromOriginal (orig->getArgOperand (4 )),
6159
- Builder2);
6160
- SmallVector<Value *, 6 > args2 = {
6161
- lookup (gutils->getNewFromOriginal (orig->getArgOperand (0 )),
6162
- Builder2),
6163
- diffe (orig, Builder2),
6164
- structarg2,
6165
- estride,
6166
- lookup (gutils->invertPointerM (orig->getArgOperand (1 ), Builder2),
6167
- Builder2),
6168
- lookup (gutils->getNewFromOriginal (orig->getArgOperand (2 )),
6169
- Builder2)};
6170
- seconddcall = Builder2.CreateCall (derivcall, args2);
6171
- }
6172
- setDiffe (orig, Constant::getNullValue (orig->getType ()), Builder2);
6173
- if (shouldFree ()) {
6174
- if (xcache)
6175
- CallInst::CreateFree (structarg1, firstdcall->getNextNode ());
6176
- if (ycache)
6177
- CallInst::CreateFree (structarg2, seconddcall->getNextNode ());
6178
- }
6179
- }
6180
-
6181
- if (gutils->knownRecomputeHeuristic .find (orig) !=
6182
- gutils->knownRecomputeHeuristic .end ()) {
6183
- if (!gutils->knownRecomputeHeuristic [orig]) {
6184
- gutils->cacheForReverse (BuilderZ, newCall,
6185
- getIndex (orig, CacheType::Self));
6186
- }
6187
- }
6188
-
6189
- if (Mode == DerivativeMode::ReverseModeGradient) {
6190
- eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
6191
- } else {
6192
- eraseIfUnused (*orig);
6193
- }
6194
- return ;
6206
+ if ((funcName == " cblas_ddot" || funcName == " cblas_sdot" )) {
6207
+ if (handleBLAS (call, called, funcName, uncacheable_args))
6208
+ return ;
6195
6209
}
6196
6210
6197
6211
if (funcName == " printf" || funcName == " puts" ||
0 commit comments