Skip to content

Commit 02dbbfd

Browse files
authored
[SYCL-MLIR] Promote f16 operands of unary arithmetic expressions (#7448)
Also, the `__imag__` operator now returns 0 when used on a scalar. Signed-off-by: Victor Perez <victor.perez@codeplay.com>
1 parent 973c182 commit 02dbbfd

File tree

13 files changed

+481
-84
lines changed

13 files changed

+481
-84
lines changed

polygeist/tools/cgeist/Lib/CGExpr.cc

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,21 +2113,21 @@ static bool isSigned(QualType Ty) {
21132113
class BinOpInfo {
21142114
public:
21152115
BinOpInfo(ValueCategory LHS, ValueCategory RHS, QualType Ty,
2116-
BinaryOperator::Opcode Opcode, const BinaryOperator *Expr)
2117-
: LHS(LHS), RHS(RHS), Ty(Ty), Opcode(Opcode), Expr(Expr) {}
2116+
BinaryOperator::Opcode Opcode, const Expr *Expr)
2117+
: LHS(LHS), RHS(RHS), Ty(Ty), Opcode(Opcode), E(Expr) {}
21182118

21192119
ValueCategory getLHS() const { return LHS; }
21202120
ValueCategory getRHS() const { return RHS; }
21212121
constexpr QualType getType() const { return Ty; }
21222122
constexpr BinaryOperator::Opcode getOpcode() const { return Opcode; }
2123-
constexpr const BinaryOperator *getExpr() const { return Expr; }
2123+
constexpr const Expr *getExpr() const { return E; }
21242124

21252125
private:
21262126
const ValueCategory LHS;
21272127
const ValueCategory RHS;
21282128
const QualType Ty; // Computation Type.
21292129
const BinaryOperator::Opcode Opcode; // Opcode of BinOp to perform
2130-
const BinaryOperator *Expr;
2130+
const Expr *E;
21312131
};
21322132

21332133
ValueCategory MLIRScanner::EmitPromoted(Expr *E, QualType PromotionType) {
@@ -2149,12 +2149,13 @@ ValueCategory MLIRScanner::EmitPromoted(Expr *E, QualType PromotionType) {
21492149
}
21502150
} else if (auto *UO = dyn_cast<UnaryOperator>(E)) {
21512151
switch (UO->getOpcode()) {
2152-
case UO_Imag:
2153-
case UO_Real:
2154-
case UO_Minus:
2155-
case UO_Plus:
2156-
mlirclang::warning() << "Default promotion for unary operation\n";
2157-
LLVM_FALLTHROUGH;
2152+
#define HANDLEUNARYOP(OP) \
2153+
case UO_##OP: \
2154+
return Visit##OP(UO, PromotionType);
2155+
2156+
#include "Expressions.def"
2157+
#undef HANDLEUNARYOP
2158+
21582159
default:
21592160
break;
21602161
}
@@ -2798,8 +2799,10 @@ ValueCategory MLIRScanner::EmitBinSub(const BinOpInfo &Info) {
27982799
const auto DiffInChars = LHS.Sub(Builder, Loc, RHS.val);
27992800

28002801
// Okay, figure out the element size.
2801-
const QualType ElementType =
2802-
Info.getExpr()->getLHS()->getType()->getPointeeType();
2802+
const QualType ElementType = cast<BinaryOperator>(Info.getExpr())
2803+
->getLHS()
2804+
->getType()
2805+
->getPointeeType();
28032806

28042807
assert(!Glob.getCGM().getContext().getAsVariableArrayType(ElementType) &&
28052808
"Not implemented yet");
@@ -2870,3 +2873,79 @@ ValueCategory MLIRScanner::EmitBinOr(const BinOpInfo &Info) {
28702873
Info.getRHS().getValue(Builder)),
28712874
/*isReference*/ false);
28722875
}
2876+
2877+
#define HANDLEUNARYOP(OP) \
2878+
ValueCategory MLIRScanner::VisitUnary##OP(UnaryOperator *E, \
2879+
QualType PromotionType) { \
2880+
LLVM_DEBUG({ \
2881+
llvm::dbgs() << "VisitUnary" #OP ": "; \
2882+
E->dump(); \
2883+
llvm::dbgs() << "\n"; \
2884+
}); \
2885+
QualType promotionTy = \
2886+
PromotionType.isNull() \
2887+
? Glob.getTypes().getPromotionType(E->getSubExpr()->getType()) \
2888+
: PromotionType; \
2889+
ValueCategory result = Visit##OP(E, promotionTy); \
2890+
if (result.val && !promotionTy.isNull()) \
2891+
result = EmitUnPromotedValue(getMLIRLocation(E->getExprLoc()), result, \
2892+
E->getType()); \
2893+
return result; \
2894+
}
2895+
#include "Expressions.def"
2896+
#undef HANDLEUNARYOP
2897+
2898+
ValueCategory MLIRScanner::VisitPlus(UnaryOperator *E, QualType PromotionType) {
2899+
if (!PromotionType.isNull())
2900+
return EmitPromotedScalarExpr(E->getSubExpr(), PromotionType);
2901+
return Visit(E->getSubExpr());
2902+
}
2903+
2904+
ValueCategory MLIRScanner::VisitMinus(UnaryOperator *E,
2905+
QualType PromotionType) {
2906+
const Location Loc = getMLIRLocation(E->getExprLoc());
2907+
ValueCategory Op;
2908+
if (!PromotionType.isNull())
2909+
Op = EmitPromotedScalarExpr(E->getSubExpr(), PromotionType);
2910+
else
2911+
Op = Visit(E->getSubExpr());
2912+
2913+
// Generate a unary FNeg for FP ops.
2914+
if (mlirclang::isFPOrFPVectorTy(Op.val.getType()))
2915+
return Op.FNeg(Builder, Loc);
2916+
2917+
// Emit unary minus with EmitBinSub so we handle overflow cases etc.
2918+
const ValueCategory Zero =
2919+
ValueCategory::getNullValue(Builder, Loc, Op.val.getType());
2920+
return EmitBinSub(
2921+
BinOpInfo{Zero, Op, E->getType(), BinaryOperator::Opcode::BO_Sub, E});
2922+
}
2923+
2924+
ValueCategory MLIRScanner::VisitImag(UnaryOperator *E, QualType PromotionType) {
2925+
Expr *Op = E->getSubExpr();
2926+
2927+
assert(!Op->getType()->isAnyComplexType() && "Unsupported");
2928+
2929+
// __imag on a scalar returns zero. Emit the subexpr to ensure side
2930+
// effects are evaluated, but not the actual value.
2931+
if (Op->isGLValue())
2932+
EmitLValue(Op);
2933+
else if (!PromotionType.isNull())
2934+
EmitPromotedScalarExpr(Op, PromotionType);
2935+
else
2936+
Visit(Op);
2937+
auto ResTy = Glob.getTypes().getMLIRType(
2938+
!PromotionType.isNull() ? PromotionType : E->getType());
2939+
return ValueCategory::getNullValue(Builder, getMLIRLocation(E->getExprLoc()),
2940+
ResTy);
2941+
}
2942+
2943+
ValueCategory MLIRScanner::VisitReal(UnaryOperator *E, QualType PromotionType) {
2944+
Expr *Op = E->getSubExpr();
2945+
2946+
assert(!Op->getType()->isAnyComplexType() && "Unsupported");
2947+
2948+
if (!PromotionType.isNull())
2949+
return EmitPromotedScalarExpr(Op, PromotionType);
2950+
return Visit(Op);
2951+
}
Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
#ifndef HANDLEBINOP
2-
#define HANDLEBINOP(X)
3-
#endif
4-
1+
#ifdef HANDLEBINOP
52
HANDLEBINOP(Mul)
63
HANDLEBINOP(Div)
74
HANDLEBINOP(Rem)
@@ -12,3 +9,11 @@ HANDLEBINOP(Shr)
129
HANDLEBINOP(And)
1310
HANDLEBINOP(Xor)
1411
HANDLEBINOP(Or)
12+
#endif
13+
14+
#ifdef HANDLEUNARYOP
15+
HANDLEUNARYOP(Imag)
16+
HANDLEUNARYOP(Real)
17+
HANDLEUNARYOP(Minus)
18+
HANDLEUNARYOP(Plus)
19+
#endif

polygeist/tools/cgeist/Lib/ValueCategory.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"
21+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2122
#include "mlir/IR/OpDefinition.h"
2223
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2324

@@ -57,6 +58,31 @@ mlir::Value ValueCategory::getValue(mlir::OpBuilder &builder) const {
5758
llvm_unreachable("type must be LLVMPointer or MemRef");
5859
}
5960

61+
ValueCategory ValueCategory::getNullValue(OpBuilder &Builder, Location Loc,
62+
Type Type) {
63+
const auto ZeroVal =
64+
llvm::TypeSwitch<mlir::Type, mlir::Value>(Type)
65+
.Case<mlir::IntegerType>([&](auto Ty) {
66+
return Builder.createOrFold<arith::ConstantIntOp>(Loc, 0, Ty);
67+
})
68+
.Case<mlir::IndexType>([&](auto) {
69+
return Builder.createOrFold<arith::ConstantIndexOp>(Loc, 0);
70+
})
71+
.Case<mlir::FloatType>([&](auto Ty) {
72+
return Builder.createOrFold<arith::ConstantFloatOp>(
73+
Loc, llvm::APFloat::getZero(Ty.getFloatSemantics()), Ty);
74+
})
75+
.Case<mlir::VectorType>([&](auto VecTy) {
76+
const auto Element = ValueCategory::getNullValue(
77+
Builder, Loc, VecTy.getElementType())
78+
.val;
79+
return Builder.createOrFold<vector::SplatOp>(Loc, Element, Type);
80+
})
81+
.Default(
82+
[](auto) -> mlir::Value { llvm_unreachable("Invalid type"); });
83+
return {ZeroVal, false};
84+
}
85+
6086
void ValueCategory::store(mlir::OpBuilder &builder, mlir::Value toStore) const {
6187
assert(isReference && "must be a reference");
6288
assert(val && "expect not-null");
@@ -644,3 +670,15 @@ ValueCategory ValueCategory::InBoundsGEPOrSubIndex(OpBuilder &Builder,
644670
ValueRange IdxList) const {
645671
return GEPOrSubIndex(Builder, Loc, Type, IdxList, /*IsInBounds*/ true);
646672
}
673+
674+
template <typename OpTy>
675+
ValueCategory FPUnaryOp(OpBuilder &Builder, Location Loc, Value Val) {
676+
assert(mlirclang::isFPOrFPVectorTy(Val.getType()) &&
677+
"Expecting FP or FP vector operand type");
678+
warnUnconstrainedOp<arith::NegFOp>();
679+
return {Builder.createOrFold<OpTy>(Loc, Val), false};
680+
}
681+
682+
ValueCategory ValueCategory::FNeg(OpBuilder &Builder, Location Loc) const {
683+
return FPUnaryOp<arith::NegFOp>(Builder, Loc, val);
684+
}

polygeist/tools/cgeist/Lib/ValueCategory.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ class ValueCategory {
4040
ValueCategory(std::nullptr_t) : val(nullptr), isReference(false) {}
4141
ValueCategory(mlir::Value val, bool isReference);
4242

43+
static ValueCategory getNullValue(mlir::OpBuilder &Builder,
44+
mlir::Location Loc, mlir::Type Type);
45+
4346
// TODO: rename to 'loadVariable'? getValue seems to generic.
4447
mlir::Value getValue(mlir::OpBuilder &Builder) const;
4548
void store(mlir::OpBuilder &Builder, ValueCategory toStore,
@@ -102,6 +105,7 @@ class ValueCategory {
102105
mlir::Value RHS, bool IsExact = false) const;
103106
ValueCategory ExactSDiv(mlir::OpBuilder &Builder, mlir::Location Loc,
104107
mlir::Value RHS) const;
108+
ValueCategory FNeg(mlir::OpBuilder &Builder, mlir::Location Loc) const;
105109
ValueCategory Neg(mlir::OpBuilder &Builder, mlir::Location Loc,
106110
bool HasNUW = false, bool HasNSW = false) const;
107111
ValueCategory Add(mlir::OpBuilder &Builder, mlir::Location Loc,

polygeist/tools/cgeist/Lib/clang-mlir.cc

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -913,36 +913,6 @@ ValueCategory MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) {
913913
return ValueCategory(res,
914914
/*isReference*/ false);
915915
}
916-
case clang::UnaryOperator::Opcode::UO_Plus: {
917-
return sub;
918-
}
919-
case clang::UnaryOperator::Opcode::UO_Minus: {
920-
Value val = sub.getValue(Builder);
921-
auto ty = val.getType();
922-
if (auto ft = ty.dyn_cast<FloatType>()) {
923-
if (auto CI = val.getDefiningOp<arith::ConstantFloatOp>()) {
924-
auto api = CI.getValue().cast<FloatAttr>().getValue();
925-
return ValueCategory(Builder.create<arith::ConstantOp>(
926-
Loc, ty, FloatAttr::get(ty, -api)),
927-
/*isReference*/ false);
928-
}
929-
return ValueCategory(Builder.create<arith::NegFOp>(Loc, val),
930-
/*isReference*/ false);
931-
} else {
932-
if (auto CI = val.getDefiningOp<arith::ConstantIntOp>()) {
933-
auto api = CI.getValue().cast<IntegerAttr>().getValue();
934-
return ValueCategory(Builder.create<arith::ConstantOp>(
935-
Loc, ty, IntegerAttr::get(ty, -api)),
936-
/*isReference*/ false);
937-
}
938-
return ValueCategory(
939-
Builder.create<arith::SubIOp>(Loc,
940-
Builder.create<arith::ConstantIntOp>(
941-
Loc, 0, ty.cast<IntegerType>()),
942-
val),
943-
/*isReference*/ false);
944-
}
945-
}
946916
case clang::UnaryOperator::Opcode::UO_PreInc:
947917
case clang::UnaryOperator::Opcode::UO_PostInc: {
948918
assert(sub.isReference);
@@ -1036,41 +1006,6 @@ ValueCategory MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) {
10361006
: next,
10371007
/*isReference*/ false);
10381008
}
1039-
case clang::UnaryOperator::Opcode::UO_Real:
1040-
case clang::UnaryOperator::Opcode::UO_Imag: {
1041-
int fnum =
1042-
(U->getOpcode() == clang::UnaryOperator::Opcode::UO_Real) ? 0 : 1;
1043-
auto lhs_v = sub.val;
1044-
assert(sub.isReference);
1045-
if (auto MT = lhs_v.getType().dyn_cast<MemRefType>()) {
1046-
auto shape = std::vector<int64_t>(MT.getShape());
1047-
shape[0] = ShapedType::kDynamicSize;
1048-
auto MT0 =
1049-
MemRefType::get(shape, MT.getElementType(),
1050-
MemRefLayoutAttrInterface(), MT.getMemorySpace());
1051-
return ValueCategory(Builder.create<polygeist::SubIndexOp>(
1052-
Loc, MT0, lhs_v, getConstantIndex(fnum)),
1053-
/*isReference*/ true);
1054-
} else if (auto PT = lhs_v.getType().dyn_cast<LLVM::LLVMPointerType>()) {
1055-
Type ET;
1056-
if (auto ST = PT.getElementType().dyn_cast<LLVM::LLVMStructType>()) {
1057-
ET = ST.getBody()[fnum];
1058-
} else {
1059-
ET = PT.getElementType().cast<LLVM::LLVMArrayType>().getElementType();
1060-
}
1061-
Value vec[2] = {Builder.create<arith::ConstantIntOp>(Loc, 0, 32),
1062-
Builder.create<arith::ConstantIntOp>(Loc, fnum, 32)};
1063-
return ValueCategory(
1064-
Builder.create<LLVM::GEPOp>(
1065-
Loc, LLVM::LLVMPointerType::get(ET, PT.getAddressSpace()), lhs_v,
1066-
vec),
1067-
/*isReference*/ true);
1068-
}
1069-
1070-
llvm::errs() << "lhs_v: " << lhs_v << "\n";
1071-
U->dump();
1072-
assert(0 && "unhandled real");
1073-
}
10741009
default: {
10751010
U->dump();
10761011
assert(0 && "unhandled opcode");

polygeist/tools/cgeist/Lib/clang-mlir.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,14 @@ class MLIRScanner : public clang::StmtVisitor<MLIRScanner, ValueCategory> {
523523
#include "Expressions.def"
524524
#undef HANDLEBINOP
525525

526+
#define HANDLEUNARYOP(OP) \
527+
ValueCategory VisitUnary##OP(clang::UnaryOperator *E, \
528+
clang::QualType PromotionTy = \
529+
clang::QualType()); \
530+
ValueCategory Visit##OP(clang::UnaryOperator *E, clang::QualType PromotionTy);
531+
#include "Expressions.def"
532+
#undef HANDLEUNARYOP
533+
526534
ValueCategory VisitCXXNoexceptExpr(clang::CXXNoexceptExpr *AS);
527535

528536
ValueCategory VisitAttributedStmt(clang::AttributedStmt *AS);

polygeist/tools/cgeist/Test/Verification/float16.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ _Float16 type(_Float16 arg) {
2323
// CHECK-EXTEND-NEXT: %[[EXT1:.*]] = arith.extf %arg1 : f16 to f32
2424
// CHECK-EXTEND-NEXT: %[[ADD:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
2525
// CHECK-EXTEND-NEXT: %[[EXT2:.*]] = arith.extf %arg2 : f16 to f32
26-
// CHECK-EXTEND-NEXT: %[[NEG:.*]] = arith.negf %arg3 : f16
27-
// CHECK-EXTEND-NEXT: %[[EXTNEG:.*]] = arith.extf %[[NEG]] : f16 to f32
28-
// CHECK-EXTEND-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT2]], %[[EXTNEG]] : f32
26+
// CHECK-EXTEND-NEXT: %[[EXT3:.*]] = arith.extf %arg3 : f16 to f32
27+
// CHECK-EXTEND-NEXT: %[[NEG:.*]] = arith.negf %[[EXT3]] : f32
28+
// CHECK-EXTEND-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT2]], %[[NEG]] : f32
2929
// CHECK-EXTEND-NEXT: %[[EXT4:.*]] = arith.extf %arg4 : f16 to f32
3030
// CHECK-EXTEND-NEXT: %[[DIV:.*]] = arith.divf %[[MUL]], %[[EXT4]] : f32
3131
// CHECK-EXTEND-NEXT: %[[SUB:.*]] = arith.subf %[[ADD]], %[[DIV]] : f32
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: cgeist %s --function=* -S | FileCheck %s
2+
3+
#include <complex.h>
4+
5+
// CHECK-LABEL: func.func @f0(
6+
// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32
7+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
8+
// CHECK: return %[[VAL_1]] : i32
9+
// CHECK: }
10+
11+
int f0(int a) {
12+
return __imag__(a);
13+
}
14+
15+
// CHECK-LABEL: func.func @f1(
16+
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32
17+
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
18+
// CHECK: return %[[VAL_1]] : f32
19+
// CHECK: }
20+
21+
float f1(float a) {
22+
return __imag__(a);
23+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: cgeist %s --function=* -S | FileCheck %s
2+
// XFAIL: *
3+
4+
#include <complex.h>
5+
6+
int f2(int complex a) {
7+
return __imag__(a);
8+
}
9+
10+
float f3(float complex a) {
11+
return __imag__(a);
12+
}

0 commit comments

Comments
 (0)