Skip to content

Commit b8d5af7

Browse files
committed
add all of cmath to type analysis
1 parent 63b7990 commit b8d5af7

File tree

1 file changed

+340
-2
lines changed

1 file changed

+340
-2
lines changed

enzyme/Enzyme/TypeAnalysis.cpp

Lines changed: 340 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,6 +1157,177 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
11571157
}
11581158
}
11591159

1160+
template<typename T>
1161+
struct Meta {
1162+
};
1163+
1164+
template<>
1165+
struct Meta<double> {
1166+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1167+
TA.updateAnalysis(val, DataType(Type::getDoubleTy(call.getContext())), &call);
1168+
}
1169+
};
1170+
1171+
template<>
1172+
struct Meta<float> {
1173+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1174+
TA.updateAnalysis(val, DataType(Type::getFloatTy(call.getContext())), &call);
1175+
}
1176+
};
1177+
1178+
template<>
1179+
struct Meta<long double> {
1180+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1181+
TA.updateAnalysis(val, DataType(Type::getX86_FP80Ty(call.getContext())), &call);
1182+
}
1183+
};
1184+
1185+
template<>
1186+
struct Meta<__float128> {
1187+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1188+
TA.updateAnalysis(val, DataType(Type::getFP128Ty (call.getContext())), &call);
1189+
}
1190+
};
1191+
1192+
template<>
1193+
struct Meta<double*> {
1194+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1195+
ValueData vd = ValueData(Type::getDoubleTy(call.getContext())).Only({0});
1196+
vd |= ValueData(IntType::Pointer);
1197+
TA.updateAnalysis(val, vd, &call);
1198+
}
1199+
};
1200+
1201+
template<>
1202+
struct Meta<float*> {
1203+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1204+
ValueData vd = ValueData(Type::getFloatTy(call.getContext())).Only({0});
1205+
vd |= ValueData(IntType::Pointer);
1206+
TA.updateAnalysis(val, vd, &call);
1207+
}
1208+
};
1209+
1210+
template<>
1211+
struct Meta<long double*> {
1212+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1213+
ValueData vd = ValueData(Type::getX86_FP80Ty(call.getContext())).Only({0});
1214+
vd |= ValueData(IntType::Pointer);
1215+
TA.updateAnalysis(val, vd, &call);
1216+
}
1217+
};
1218+
1219+
template<>
1220+
struct Meta<__float128*> {
1221+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1222+
ValueData vd = ValueData(Type::getFP128Ty(call.getContext())).Only({0});
1223+
vd |= ValueData(IntType::Pointer);
1224+
TA.updateAnalysis(val, vd, &call);
1225+
}
1226+
};
1227+
1228+
1229+
template<>
1230+
struct Meta<void> {
1231+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1232+
}
1233+
};
1234+
1235+
template<>
1236+
struct Meta<void*> {
1237+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1238+
ValueData vd = ValueData(IntType::Pointer);
1239+
TA.updateAnalysis(val, vd, &call);
1240+
}
1241+
};
1242+
1243+
template<>
1244+
struct Meta<int> {
1245+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1246+
ValueData vd = ValueData(IntType::Integer);
1247+
TA.updateAnalysis(val, vd, &call);
1248+
}
1249+
};
1250+
1251+
template<>
1252+
struct Meta<int*> {
1253+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1254+
ValueData vd = ValueData(IntType::Integer).Only({0});
1255+
vd |= ValueData(IntType::Pointer);
1256+
TA.updateAnalysis(val, vd, &call);
1257+
}
1258+
};
1259+
1260+
template<>
1261+
struct Meta<long int> {
1262+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1263+
ValueData vd = ValueData(IntType::Integer);
1264+
TA.updateAnalysis(val, vd, &call);
1265+
}
1266+
};
1267+
1268+
template<>
1269+
struct Meta<long int*> {
1270+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1271+
ValueData vd = ValueData(IntType::Integer).Only({0});
1272+
vd |= ValueData(IntType::Pointer);
1273+
TA.updateAnalysis(val, vd, &call);
1274+
}
1275+
};
1276+
1277+
template<>
1278+
struct Meta<long unsigned int> {
1279+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1280+
ValueData vd = ValueData(IntType::Integer);
1281+
TA.updateAnalysis(val, vd, &call);
1282+
}
1283+
};
1284+
1285+
template<>
1286+
struct Meta<long unsigned int*> {
1287+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1288+
ValueData vd = ValueData(IntType::Integer).Only({0});
1289+
vd |= ValueData(IntType::Pointer);
1290+
TA.updateAnalysis(val, vd, &call);
1291+
}
1292+
};
1293+
1294+
template<>
1295+
struct Meta<long long int> {
1296+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1297+
ValueData vd = ValueData(IntType::Integer);
1298+
TA.updateAnalysis(val, vd, &call);
1299+
}
1300+
};
1301+
1302+
template<>
1303+
struct Meta<long long int*> {
1304+
static void analyzeType(Value* val, CallInst &call, TypeAnalyzer &TA) {
1305+
ValueData vd = ValueData(IntType::Integer).Only({0});
1306+
vd |= ValueData(IntType::Pointer);
1307+
TA.updateAnalysis(val, vd, &call);
1308+
}
1309+
};
1310+
1311+
template<typename... Arg0>
1312+
struct FunctionTemplatesSuck {
1313+
static void analyzeFuncTypesHelper(unsigned idx, CallInst& call, TypeAnalyzer& TA) {}
1314+
};
1315+
1316+
template<typename Arg0, typename... Args>
1317+
struct FunctionTemplatesSuck<Arg0, Args...> {
1318+
static void analyzeFuncTypesHelper(unsigned idx, CallInst& call, TypeAnalyzer& TA) {
1319+
Meta<Arg0>::analyzeType(call.getOperand(idx), call, TA);
1320+
FunctionTemplatesSuck<Args...>::analyzeFuncTypesHelper(idx+1, call, TA);
1321+
}
1322+
};
1323+
1324+
1325+
template<typename RT, typename... Args>
1326+
void analyzeFuncTypes( RT(*fn)(Args...), CallInst& call, TypeAnalyzer& TA) {
1327+
Meta<RT>::analyzeType(&call, call, TA);
1328+
FunctionTemplatesSuck<Args...>::analyzeFuncTypesHelper(0, call, TA);
1329+
}
1330+
11601331
void TypeAnalyzer::visitCallInst(CallInst &call) {
11611332
if (auto iasm = dyn_cast<InlineAsm>(call.getCalledValue())) {
11621333
if (iasm->getAsmString() == "cpuid") {
@@ -1169,21 +1340,188 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
11691340

11701341
if (Function* ci = call.getCalledFunction()) {
11711342

1343+
#define CONSIDER(fn)\
1344+
if (ci->getName() == #fn) {\
1345+
analyzeFuncTypes(::fn, call, *this);\
1346+
return;\
1347+
}
1348+
1349+
CONSIDER(malloc)
1350+
//CONSIDER(__lgamma_r_finite)
1351+
CONSIDER(frexp)
1352+
CONSIDER(frexpf)
1353+
CONSIDER(frexpl)
1354+
CONSIDER(ldexp)
1355+
CONSIDER(modf)
1356+
1357+
CONSIDER(cos)
1358+
CONSIDER(sin)
1359+
CONSIDER(tan)
1360+
CONSIDER(acos)
1361+
CONSIDER(asin)
1362+
CONSIDER(atan)
1363+
CONSIDER(atan2)
1364+
CONSIDER(cosh)
1365+
CONSIDER(sinh)
1366+
CONSIDER(tanh)
1367+
CONSIDER(acosh)
1368+
CONSIDER(acoshf)
1369+
CONSIDER(acoshl)
1370+
CONSIDER(asinh)
1371+
CONSIDER(asinhf)
1372+
CONSIDER(asinhl)
1373+
CONSIDER(atanh)
1374+
CONSIDER(atanhl)
1375+
CONSIDER(atanhf)
1376+
CONSIDER(exp)
1377+
CONSIDER(log)
1378+
CONSIDER(log10)
1379+
CONSIDER(exp2)
1380+
CONSIDER(exp2f)
1381+
CONSIDER(exp2l)
1382+
CONSIDER(log10)
1383+
CONSIDER(exp2)
1384+
CONSIDER(expm1)
1385+
CONSIDER(expm1f)
1386+
CONSIDER(expm1l)
1387+
CONSIDER(ilogb)
1388+
CONSIDER(ilogbf)
1389+
CONSIDER(ilogbl)
1390+
CONSIDER(log1p)
1391+
CONSIDER(log1pf)
1392+
CONSIDER(log1pl)
1393+
CONSIDER(log2)
1394+
CONSIDER(log2f)
1395+
CONSIDER(log2l)
1396+
CONSIDER(logb)
1397+
CONSIDER(logbf)
1398+
CONSIDER(logbl)
1399+
CONSIDER(scalbn)
1400+
CONSIDER(scalbnf)
1401+
CONSIDER(scalbnl)
1402+
CONSIDER(scalbln)
1403+
CONSIDER(scalblnf)
1404+
CONSIDER(scalblnl)
1405+
CONSIDER(pow)
1406+
CONSIDER(sqrt)
1407+
CONSIDER(cbrt)
1408+
CONSIDER(cbrtf)
1409+
CONSIDER(cbrtl)
1410+
CONSIDER(hypot)
1411+
CONSIDER(erf)
1412+
CONSIDER(erff)
1413+
CONSIDER(erfl)
1414+
CONSIDER(erfc)
1415+
CONSIDER(erfcf)
1416+
CONSIDER(erfcl)
1417+
CONSIDER(tgamma)
1418+
CONSIDER(tgammaf)
1419+
CONSIDER(tgammal)
1420+
CONSIDER(lgamma)
1421+
CONSIDER(lgammaf)
1422+
CONSIDER(lgammal)
1423+
CONSIDER(ceil)
1424+
CONSIDER(floor)
1425+
CONSIDER(fmod)
1426+
CONSIDER(trunc)
1427+
CONSIDER(truncf)
1428+
CONSIDER(truncl)
1429+
CONSIDER(round)
1430+
CONSIDER(roundf)
1431+
CONSIDER(roundl)
1432+
CONSIDER(lround)
1433+
CONSIDER(lroundf)
1434+
CONSIDER(lroundl)
1435+
CONSIDER(llround)
1436+
CONSIDER(llroundf)
1437+
CONSIDER(llroundl)
1438+
CONSIDER(rint)
1439+
CONSIDER(rintf)
1440+
CONSIDER(rintl)
1441+
CONSIDER(lrint)
1442+
CONSIDER(lrintf)
1443+
CONSIDER(lrintl)
1444+
CONSIDER(llrint)
1445+
CONSIDER(llrintf)
1446+
CONSIDER(llrintl)
1447+
CONSIDER(remainder)
1448+
CONSIDER(remainderf)
1449+
CONSIDER(remainderl)
1450+
CONSIDER(remquo)
1451+
CONSIDER(remquof)
1452+
CONSIDER(remquol)
1453+
CONSIDER(copysign)
1454+
CONSIDER(copysignf)
1455+
CONSIDER(copysignl)
1456+
CONSIDER(nextafter)
1457+
CONSIDER(nextafterf)
1458+
CONSIDER(nextafterl)
1459+
CONSIDER(nexttoward)
1460+
CONSIDER(nexttowardf)
1461+
CONSIDER(nexttowardl)
1462+
CONSIDER(fdim)
1463+
CONSIDER(fdimf)
1464+
CONSIDER(fdiml)
1465+
CONSIDER(fmax)
1466+
CONSIDER(fmaxf)
1467+
CONSIDER(fmaxl)
1468+
CONSIDER(fmin)
1469+
CONSIDER(fminf)
1470+
CONSIDER(fminl)
1471+
CONSIDER(fabs)
1472+
CONSIDER(fma)
1473+
CONSIDER(fmaf)
1474+
CONSIDER(fmal)
1475+
1476+
1477+
if (ci->getName() == "__lgamma_r_finite") {
1478+
updateAnalysis(call.getArgOperand(0), DataType(Type::getDoubleTy(call.getContext())), &call);
1479+
updateAnalysis(call.getArgOperand(1), ValueData(IntType::Integer).Only({0}), &call);
1480+
updateAnalysis(&call, DataType(Type::getDoubleTy(call.getContext())), &call);
1481+
}
1482+
1483+
/*
11721484
if (ci->getName() == "malloc") {
11731485
updateAnalysis(call.getArgOperand(0), IntType::Integer, &call);
11741486
}
11751487
1176-
if (ci->getName() == "__lgamma_r_finite") {
1488+
1489+
1490+
if (ci->getName() == "frexp") {
11771491
updateAnalysis(call.getArgOperand(0), DataType(Type::getDoubleTy(call.getContext())), &call);
11781492
updateAnalysis(call.getArgOperand(1), ValueData(IntType::Integer).Only({0}), &call);
11791493
updateAnalysis(&call, DataType(Type::getDoubleTy(call.getContext())), &call);
11801494
}
11811495
1182-
if (ci->getName() == "tanh") {
1496+
if (ci->getName() == "ldexp") {
1497+
updateAnalysis(call.getArgOperand(0), DataType(Type::getDoubleTy(call.getContext())), &call);
1498+
updateAnalysis(call.getArgOperand(1), ValueData(IntType::Integer), &call);
1499+
updateAnalysis(&call, DataType(Type::getDoubleTy(call.getContext())), &call);
1500+
}
1501+
1502+
if (ci->getName() == "modf") {
11831503
updateAnalysis(call.getArgOperand(0), DataType(Type::getDoubleTy(call.getContext())), &call);
1504+
updateAnalysis(call.getArgOperand(1), ValueData(Type::getDoubleTy(call.getContext())).Only({0}), &call);
11841505
updateAnalysis(&call, DataType(Type::getDoubleTy(call.getContext())), &call);
11851506
}
11861507
1508+
if (ci->getName() == "sin")
1509+
analyzeFuncTypes(sin, call, *this);
1510+
1511+
const std::vector<std::string> doubleCmath = {
1512+
"cos", "sin", "tan", "acos", "asin", "atan", "atan2",
1513+
"cosh", "sinh", "tanh", "acosh", "asinh", "atanh",
1514+
"exp", "log", "log10", "exp2", "expm1"
1515+
};
1516+
const std::vector<std::string> floatCMath = {"acoshf", "asinhf", "atanhf", "exp2f", "expm1f"};
1517+
const std::vector<std::string> longDoubleCMath = {"acoshl", "asinhl", "atanhl", "exp2l", "expm1l"};
1518+
1519+
if (std::find(doubleCmath.begin(), doubleCmath.end(), ci->getName().str()) != doubleCmath.end()) {
1520+
for(unsigned i=0; i<call.getNumArgOperands(); i++)
1521+
updateAnalysis(call.getArgOperand(i), DataType(Type::getDoubleTy(call.getContext())), &call);
1522+
updateAnalysis(&call, DataType(Type::getDoubleTy(call.getContext())), &call);
1523+
}*/
1524+
11871525
//TODO we should handle calls interprocedurally, allowing better propagation of type information
11881526
if (!ci->empty()) {
11891527
visitIPOCall(call, *ci);

0 commit comments

Comments
 (0)