Skip to content

Commit 8243166

Browse files
committed
[SYCL-MLIR] Improve code generation for the rest of binary operators
9ac759f improved code generation for addition and subtraction. This commit addresses the rest of binary arithmetic operators (mul, div, rem, shl, shr, and, or, xor): - Enable code generatio for mul and div operators and vector operands; - Mask shl and shr operands in OpenCL - Always generate signed/unsigned operations for vector operands. Signed-off-by: Victor Perez <victor.perez@codeplay.com>
1 parent 02dbbfd commit 8243166

File tree

14 files changed

+2316
-92
lines changed

14 files changed

+2316
-92
lines changed

polygeist/tools/cgeist/Lib/CGExpr.cc

Lines changed: 133 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,18 +2098,6 @@ ValueCategory MLIRScanner::VisitBinAssign(BinaryOperator *E) {
20982098
return RHS;
20992099
}
21002100

2101-
static bool isSigned(QualType Ty) {
2102-
// TODO note assumptions made here about unsigned / unordered
2103-
bool SignedType = true;
2104-
if (const auto *Bit = dyn_cast<clang::BuiltinType>(Ty)) {
2105-
if (Bit->isUnsignedInteger())
2106-
SignedType = false;
2107-
if (Bit->isSignedInteger())
2108-
SignedType = true;
2109-
}
2110-
return SignedType;
2111-
}
2112-
21132101
class BinOpInfo {
21142102
public:
21152103
BinOpInfo(ValueCategory LHS, ValueCategory RHS, QualType Ty,
@@ -2577,39 +2565,62 @@ static void informNoOverflowCheck(LangOptions::SignedOverflowBehaviorTy SOB,
25772565
}
25782566

25792567
ValueCategory MLIRScanner::EmitBinMul(const BinOpInfo &Info) {
2580-
auto LHSVal = Info.getLHS().getValue(Builder);
2581-
auto RHSVal = Info.getRHS().getValue(Builder);
2582-
if (LHSVal.getType().isa<mlir::FloatType>())
2583-
return ValueCategory(Builder.create<arith::MulFOp>(Loc, LHSVal, RHSVal),
2584-
/*isReference*/ false);
2585-
return ValueCategory(Builder.create<arith::MulIOp>(Loc, LHSVal, RHSVal),
2586-
/*isReference*/ false);
2568+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2569+
const auto LHS = Info.getLHS();
2570+
const auto RHS = Info.getRHS().val;
2571+
2572+
if (Info.getType()->isSignedIntegerOrEnumerationType()) {
2573+
informNoOverflowCheck(
2574+
Glob.getCGM().getLangOpts().getSignedOverflowBehavior(), "mul");
2575+
return LHS.Mul(Builder, Loc, RHS);
2576+
}
2577+
2578+
assert(!Info.getType()->isConstantMatrixType() && "Not yet implemented");
2579+
2580+
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType()))
2581+
return LHS.FMul(Builder, Loc, RHS);
2582+
return LHS.Mul(Builder, Loc, RHS);
25872583
}
25882584

25892585
ValueCategory MLIRScanner::EmitBinDiv(const BinOpInfo &Info) {
2590-
auto LHSVal = Info.getLHS().getValue(Builder);
2591-
auto RHSVal = Info.getRHS().getValue(Builder);
2592-
if (LHSVal.getType().isa<mlir::FloatType>())
2593-
return ValueCategory(Builder.create<arith::DivFOp>(Loc, LHSVal, RHSVal),
2594-
/*isReference*/ false);
2595-
if (isSigned(Info.getType()))
2596-
return ValueCategory(Builder.create<arith::DivSIOp>(Loc, LHSVal, RHSVal),
2597-
/*isReference*/ false);
2598-
return ValueCategory(Builder.create<arith::DivUIOp>(Loc, LHSVal, RHSVal),
2599-
/*isReference*/ false);
2586+
mlirclang::warning()
2587+
<< "Not checking division by zero nor signed integer overflow.\n";
2588+
2589+
assert(!Info.getType()->isConstantMatrixType() && "Not implemented");
2590+
2591+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2592+
const auto LHS = Info.getLHS();
2593+
const auto RHS = Info.getRHS().val;
2594+
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType())) {
2595+
const auto &LangOpts = Glob.getCGM().getLangOpts();
2596+
const auto &CodeGenOpts = Glob.getCGM().getCodeGenOpts();
2597+
if ((LangOpts.OpenCL && !CodeGenOpts.OpenCLCorrectlyRoundedDivSqrt) ||
2598+
(LangOpts.HIP && LangOpts.CUDAIsDevice &&
2599+
!CodeGenOpts.HIPCorrectlyRoundedDivSqrt)) {
2600+
// OpenCL v1.1 s7.4: minimum accuracy of single precision / is 2.5ulp
2601+
// OpenCL v1.2 s5.6.4.2: The -cl-fp32-correctly-rounded-divide-sqrt
2602+
// build option allows an application to specify that single precision
2603+
// floating-point divide (x/y and 1/x) and sqrt used in the program
2604+
// source are correctly rounded.
2605+
mlirclang::warning() << "Not applying OpenCL/HIP precision options.\n";
2606+
}
2607+
return LHS.FDiv(Builder, Loc, RHS);
2608+
}
2609+
if (Info.getType()->hasUnsignedIntegerRepresentation())
2610+
return LHS.UDiv(Builder, Loc, RHS);
2611+
return LHS.SDiv(Builder, Loc, RHS);
26002612
}
26012613

26022614
ValueCategory MLIRScanner::EmitBinRem(const BinOpInfo &Info) {
2603-
auto LHSVal = Info.getLHS().getValue(Builder);
2604-
auto RHSVal = Info.getRHS().getValue(Builder);
2605-
if (LHSVal.getType().isa<mlir::FloatType>())
2606-
return ValueCategory(Builder.create<arith::RemFOp>(Loc, LHSVal, RHSVal),
2607-
/*isReference*/ false);
2608-
if (isSigned(Info.getType()))
2609-
return ValueCategory(Builder.create<arith::RemSIOp>(Loc, LHSVal, RHSVal),
2610-
/*isReference*/ false);
2611-
return ValueCategory(Builder.create<arith::RemUIOp>(Loc, LHSVal, RHSVal),
2612-
/*isReference*/ false);
2615+
mlirclang::warning()
2616+
<< "Not checking division by zero nor signed integer overflow.\n";
2617+
2618+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2619+
const auto LHS = Info.getLHS();
2620+
const auto RHS = Info.getRHS().val;
2621+
if (Info.getType()->hasUnsignedIntegerRepresentation())
2622+
return LHS.URem(Builder, Loc, RHS);
2623+
return LHS.SRem(Builder, Loc, RHS);
26132624
}
26142625

26152626
/// Casts index of subindex operation conditionally.
@@ -2821,57 +2832,102 @@ ValueCategory MLIRScanner::EmitBinSub(const BinOpInfo &Info) {
28212832
return DiffInChars.ExactSDiv(Builder, Loc, Divisor);
28222833
}
28232834

2835+
static mlir::Value GetWidthMinusOneValue(mlir::OpBuilder &Builder,
2836+
mlir::Location Loc, mlir::Value LHS,
2837+
mlir::Value RHS) {
2838+
auto Ty = LHS.getType();
2839+
IntegerType IntTy;
2840+
if (auto VT = Ty.dyn_cast<mlir::VectorType>())
2841+
IntTy = VT.getElementType().cast<IntegerType>();
2842+
else
2843+
IntTy = Ty.cast<IntegerType>();
2844+
2845+
const auto WidthMinusOne = IntTy.getWidth() - 1;
2846+
ValueCategory Val{
2847+
Builder.createOrFold<arith::ConstantIntOp>(Loc, WidthMinusOne, IntTy),
2848+
false};
2849+
if (auto VT = Ty.dyn_cast<mlir::VectorType>())
2850+
Val = Val.Splat(Builder, Loc, VT);
2851+
return Val.val;
2852+
}
2853+
2854+
ValueCategory MLIRScanner::ConstrainShiftValue(ValueCategory LHS,
2855+
ValueCategory RHS) {
2856+
IntegerType Ty;
2857+
if (auto VT = LHS.val.getType().dyn_cast<mlir::VectorType>())
2858+
Ty = VT.getElementType().cast<IntegerType>();
2859+
else
2860+
Ty = LHS.val.getType().cast<IntegerType>();
2861+
2862+
if (llvm::isPowerOf2_64(Ty.getWidth()))
2863+
return RHS.And(Builder, Loc,
2864+
GetWidthMinusOneValue(Builder, Loc, LHS.val, RHS.val));
2865+
return RHS.URem(Builder, Loc,
2866+
Builder.createOrFold<arith::ConstantIntOp>(
2867+
Loc, Ty.getWidth(), RHS.val.getType()));
2868+
}
2869+
28242870
ValueCategory MLIRScanner::EmitBinShl(const BinOpInfo &Info) {
2825-
auto LHSVal = Info.getLHS().getValue(Builder);
2826-
auto RHSVal = Info.getRHS().getValue(Builder);
2827-
auto PrevTy = RHSVal.getType().cast<mlir::IntegerType>();
2828-
auto PostTy = LHSVal.getType().cast<mlir::IntegerType>();
2829-
if (PrevTy.getWidth() < PostTy.getWidth())
2830-
RHSVal = Builder.create<arith::ExtUIOp>(Loc, PostTy, RHSVal);
2831-
if (PrevTy.getWidth() > PostTy.getWidth())
2832-
RHSVal = Builder.create<arith::TruncIOp>(Loc, PostTy, RHSVal);
2833-
assert(LHSVal.getType() == RHSVal.getType());
2834-
return ValueCategory(Builder.create<arith::ShLIOp>(Loc, LHSVal, RHSVal),
2835-
/*isReference*/ false);
2871+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2872+
auto LHS = Info.getLHS();
2873+
auto RHS = Info.getRHS();
2874+
2875+
// LLVM requires the LHS and RHS to be the same type: promote or truncate the
2876+
// RHS to the same size as the LHS.
2877+
if (LHS.val.getType() != RHS.val.getType())
2878+
RHS = RHS.IntCast(Builder, Loc, LHS.val.getType(), /*IsSigned*/ false);
2879+
2880+
if (Glob.getCGM().getLangOpts().OpenCL) {
2881+
this->Loc = Loc;
2882+
RHS = ConstrainShiftValue(LHS, RHS);
2883+
} else {
2884+
mlirclang::warning() << "Not performing SHL checks\n";
2885+
}
2886+
2887+
return LHS.Shl(Builder, Loc, RHS.val);
28362888
}
28372889

28382890
ValueCategory MLIRScanner::EmitBinShr(const BinOpInfo &Info) {
2839-
auto LHSVal = Info.getLHS().getValue(Builder);
2840-
auto RHSVal = Info.getRHS().getValue(Builder);
2841-
auto PrevTy = RHSVal.getType().cast<mlir::IntegerType>();
2842-
auto PostTy = LHSVal.getType().cast<mlir::IntegerType>();
2843-
if (PrevTy.getWidth() < PostTy.getWidth())
2844-
RHSVal = Builder.create<mlir::arith::ExtUIOp>(Loc, PostTy, RHSVal);
2845-
if (PrevTy.getWidth() > PostTy.getWidth())
2846-
RHSVal = Builder.create<mlir::arith::TruncIOp>(Loc, PostTy, RHSVal);
2847-
assert(LHSVal.getType() == RHSVal.getType());
2848-
if (isSigned(Info.getExpr()->getType()))
2849-
return ValueCategory(Builder.create<arith::ShRSIOp>(Loc, LHSVal, RHSVal),
2850-
/*isReference*/ false);
2851-
return ValueCategory(Builder.create<arith::ShRUIOp>(Loc, LHSVal, RHSVal),
2852-
/*isReference*/ false);
2891+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2892+
auto LHS = Info.getLHS();
2893+
auto RHS = Info.getRHS();
2894+
2895+
// LLVM requires the LHS and RHS to be the same type: promote or truncate the
2896+
// RHS to the same size as the LHS.
2897+
if (LHS.val.getType() != RHS.val.getType())
2898+
RHS = RHS.IntCast(Builder, Loc, LHS.val.getType(), /*IsSigned*/ false);
2899+
2900+
if (Glob.getCGM().getLangOpts().OpenCL) {
2901+
this->Loc = Loc;
2902+
RHS = ConstrainShiftValue(LHS, RHS);
2903+
} else {
2904+
mlirclang::warning() << "Not performing SHL checks\n";
2905+
}
2906+
2907+
if (Info.getType()->hasUnsignedIntegerRepresentation())
2908+
return LHS.LShr(Builder, Loc, RHS.val);
2909+
return LHS.AShr(Builder, Loc, RHS.val);
28532910
}
28542911

28552912
ValueCategory MLIRScanner::EmitBinAnd(const BinOpInfo &Info) {
2856-
return ValueCategory(
2857-
Builder.create<arith::AndIOp>(Loc, Info.getLHS().getValue(Builder),
2858-
Info.getRHS().getValue(Builder)),
2859-
/*isReference*/ false);
2913+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2914+
auto LHS = Info.getLHS();
2915+
auto RHS = Info.getRHS();
2916+
return LHS.And(Builder, Loc, RHS.val);
28602917
}
28612918

28622919
ValueCategory MLIRScanner::EmitBinXor(const BinOpInfo &Info) {
2863-
return ValueCategory(
2864-
Builder.create<arith::XOrIOp>(Loc, Info.getLHS().getValue(Builder),
2865-
Info.getRHS().getValue(Builder)),
2866-
/*isReference*/ false);
2920+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2921+
auto LHS = Info.getLHS();
2922+
auto RHS = Info.getRHS();
2923+
return LHS.Xor(Builder, Loc, RHS.val);
28672924
}
28682925

28692926
ValueCategory MLIRScanner::EmitBinOr(const BinOpInfo &Info) {
2870-
// TODO short circuit
2871-
return ValueCategory(
2872-
Builder.create<arith::OrIOp>(Loc, Info.getLHS().getValue(Builder),
2873-
Info.getRHS().getValue(Builder)),
2874-
/*isReference*/ false);
2927+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2928+
auto LHS = Info.getLHS();
2929+
auto RHS = Info.getRHS();
2930+
return LHS.Or(Builder, Loc, RHS.val);
28752931
}
28762932

28772933
#define HANDLEUNARYOP(OP) \

polygeist/tools/cgeist/Lib/ValueCategory.cc

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,13 @@ template <typename OpTy> inline void warnUnconstrainedOp() {
323323
<< "\n";
324324
}
325325

326+
template <typename OpTy> inline void warnNonExactOp(bool IsExact) {
327+
if (!IsExact)
328+
return;
329+
mlirclang::warning() << "Creating exact " << OpTy::getOperationName()
330+
<< " is not suported.\n";
331+
}
332+
326333
ValueCategory ValueCategory::FPTrunc(OpBuilder &Builder, Location Loc,
327334
Type PromotionType) const {
328335
assert(val.getType().isa<FloatType>() &&
@@ -569,10 +576,37 @@ static ValueCategory NUWNSWBinOp(mlir::OpBuilder &Builder, mlir::Location Loc,
569576
return IntBinOp<OpTy>(Builder, Loc, LHS, RHS);
570577
}
571578

579+
ValueCategory ValueCategory::Mul(OpBuilder &Builder, Location Loc, Value RHS,
580+
bool HasNUW, bool HasNSW) const {
581+
return NUWNSWBinOp<arith::MulIOp>(Builder, Loc, val, RHS, HasNUW, HasNSW);
582+
}
583+
584+
ValueCategory ValueCategory::FMul(OpBuilder &Builder, Location Loc,
585+
Value RHS) const {
586+
warnUnconstrainedOp<arith::DivFOp>();
587+
return FPBinOp<arith::MulFOp>(Builder, Loc, val, RHS);
588+
}
589+
590+
ValueCategory ValueCategory::FDiv(OpBuilder &Builder, Location Loc,
591+
Value RHS) const {
592+
warnUnconstrainedOp<arith::DivFOp>();
593+
return FPBinOp<arith::DivFOp>(Builder, Loc, val, RHS);
594+
}
595+
596+
ValueCategory ValueCategory::UDiv(OpBuilder &Builder, Location Loc, Value RHS,
597+
bool IsExact) const {
598+
warnNonExactOp<arith::DivUIOp>(IsExact);
599+
return IntBinOp<arith::DivUIOp>(Builder, Loc, val, RHS);
600+
}
601+
602+
ValueCategory ValueCategory::ExactUDiv(OpBuilder &Builder, Location Loc,
603+
Value RHS) const {
604+
return UDiv(Builder, Loc, RHS, /*IsExact*/ true);
605+
}
606+
572607
ValueCategory ValueCategory::SDiv(OpBuilder &Builder, Location Loc, Value RHS,
573608
bool IsExact) const {
574-
if (IsExact)
575-
mlirclang::warning() << "Creating exact division is not supported\n";
609+
warnNonExactOp<arith::DivSIOp>(IsExact);
576610
return IntBinOp<arith::DivSIOp>(Builder, Loc, val, RHS);
577611
}
578612

@@ -581,6 +615,16 @@ ValueCategory ValueCategory::ExactSDiv(OpBuilder &Builder, Location Loc,
581615
return SDiv(Builder, Loc, RHS, /*IsExact*/ true);
582616
}
583617

618+
ValueCategory ValueCategory::URem(OpBuilder &Builder, Location Loc,
619+
Value RHS) const {
620+
return IntBinOp<arith::RemUIOp>(Builder, Loc, val, RHS);
621+
}
622+
623+
ValueCategory ValueCategory::SRem(OpBuilder &Builder, Location Loc,
624+
Value RHS) const {
625+
return IntBinOp<arith::RemSIOp>(Builder, Loc, val, RHS);
626+
}
627+
584628
ValueCategory ValueCategory::Neg(OpBuilder &Builder, Location Loc, bool HasNUW,
585629
bool HasNSW) const {
586630
ValueCategory Zero(Builder.createOrFold<ConstantIntOp>(Loc, 0, val.getType()),
@@ -682,3 +726,35 @@ ValueCategory FPUnaryOp(OpBuilder &Builder, Location Loc, Value Val) {
682726
ValueCategory ValueCategory::FNeg(OpBuilder &Builder, Location Loc) const {
683727
return FPUnaryOp<arith::NegFOp>(Builder, Loc, val);
684728
}
729+
730+
ValueCategory ValueCategory::Shl(OpBuilder &Builder, Location Loc, Value RHS,
731+
bool HasNUW, bool HasNSW) const {
732+
return NUWNSWBinOp<arith::ShLIOp>(Builder, Loc, val, RHS, HasNUW, HasNSW);
733+
}
734+
735+
ValueCategory ValueCategory::AShr(OpBuilder &Builder, Location Loc, Value RHS,
736+
bool IsExact) const {
737+
warnNonExactOp<arith::ShRSIOp>(IsExact);
738+
return IntBinOp<arith::ShRSIOp>(Builder, Loc, val, RHS);
739+
}
740+
741+
ValueCategory ValueCategory::LShr(OpBuilder &Builder, Location Loc, Value RHS,
742+
bool IsExact) const {
743+
warnNonExactOp<arith::ShRUIOp>(IsExact);
744+
return IntBinOp<arith::ShRUIOp>(Builder, Loc, val, RHS);
745+
}
746+
747+
ValueCategory ValueCategory::And(mlir::OpBuilder &Builder, mlir::Location Loc,
748+
mlir::Value RHS) const {
749+
return IntBinOp<arith::AndIOp>(Builder, Loc, val, RHS);
750+
}
751+
752+
ValueCategory ValueCategory::Or(mlir::OpBuilder &Builder, mlir::Location Loc,
753+
mlir::Value RHS) const {
754+
return IntBinOp<arith::OrIOp>(Builder, Loc, val, RHS);
755+
}
756+
757+
ValueCategory ValueCategory::Xor(mlir::OpBuilder &Builder, mlir::Location Loc,
758+
mlir::Value RHS) const {
759+
return IntBinOp<arith::XOrIOp>(Builder, Loc, val, RHS);
760+
}

0 commit comments

Comments
 (0)