Skip to content

Commit 9a9c96d

Browse files
committed
Update on "[CIR] Add reconcile unrealized casts pass"
[ghstack-poisoned]
2 parents 15be706 + 2d83c0b commit 9a9c96d

File tree

21 files changed

+310
-87
lines changed

21 files changed

+310
-87
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,9 @@ def StructElementAddr : CIR_Op<"struct_element_addr"> {
13771377
The `cir.struct_element_addr` operaration gets the address of a particular
13781378
named member from the input struct.
13791379

1380+
It expects a pointer to the base struct as well as the name of the member
1381+
and its field index.
1382+
13801383
Example:
13811384
```mlir
13821385
!ty_22struct2EBar22 = type !cir.struct<"struct.Bar", i32, i8>
@@ -1391,10 +1394,25 @@ def StructElementAddr : CIR_Op<"struct_element_addr"> {
13911394

13921395
let arguments = (ins
13931396
Arg<CIR_PointerType, "the address to load from", [MemRead]>:$struct_addr,
1394-
StrAttr:$member_name);
1397+
StrAttr:$member_name,
1398+
IndexAttr:$member_index);
13951399

13961400
let results = (outs Res<CIR_PointerType, "">:$result);
13971401

1402+
let builders = [
1403+
OpBuilder<(ins "Type":$type, "Value":$value, "llvm::StringRef":$name,
1404+
"unsigned":$index),
1405+
[{
1406+
mlir::APInt fieldIdx(64, index);
1407+
build($_builder, $_state, type, value, name, fieldIdx);
1408+
}]>
1409+
];
1410+
1411+
let extraClassDeclaration = [{
1412+
/// Return the index of the struct member being accessed.
1413+
uint64_t getIndex() { return getMemberIndex().getZExtValue(); }
1414+
}];
1415+
13981416
// FIXME: add verifier.
13991417
}
14001418

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ static void buildLValueForAnyFieldInitialization(CIRGenFunction &CGF,
207207
if (MemberInit->isIndirectMemberInitializer()) {
208208
llvm_unreachable("NYI");
209209
} else {
210-
LHS = CGF.buildLValueForFieldInitialization(LHS, Field, Field->getName());
210+
LHS = CGF.buildLValueForFieldInitialization(LHS, Field, Field->getName(),
211+
Field->getFieldIndex());
211212
}
212213
}
213214

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
#include "CIRGenCstEmitter.h"
1717
#include "CIRGenFunction.h"
1818
#include "CIRGenModule.h"
19+
#include "CIRGenValue.h"
1920
#include "UnimplementedFeatureGuarding.h"
2021

2122
#include "clang/AST/GlobalDecl.h"
2223
#include "clang/Basic/Builtins.h"
2324
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2425
#include "clang/CIR/Dialect/IR/CIRTypes.h"
2526
#include "llvm/Support/Casting.h"
27+
#include "llvm/Support/ErrorHandling.h"
2628

2729
#include "mlir/Dialect/Func/IR/FuncOps.h"
2830
#include "mlir/IR/Value.h"
@@ -50,7 +52,8 @@ static Address buildPreserveStructAccess(CIRGenFunction &CGF, LValue base,
5052
/// doesn't necessarily have the right type.
5153
static Address buildAddrOfFieldStorage(CIRGenFunction &CGF, Address Base,
5254
const FieldDecl *field,
53-
llvm::StringRef fieldName) {
55+
llvm::StringRef fieldName,
56+
unsigned fieldIndex) {
5457
if (field->isZeroSize(CGF.getContext()))
5558
llvm_unreachable("NYI");
5659

@@ -63,7 +66,7 @@ static Address buildAddrOfFieldStorage(CIRGenFunction &CGF, Address Base,
6366
// which do not currently carry the name, so it can be passed down from the
6467
// CaptureStmt.
6568
auto sea = CGF.getBuilder().create<mlir::cir::StructElementAddr>(
66-
loc, fieldPtr, Base.getPointer(), fieldName);
69+
loc, fieldPtr, Base.getPointer(), fieldName, fieldIndex);
6770

6871
// TODO: We could get the alignment from the CIRGenRecordLayout, but given the
6972
// member name based lookup of the member here we probably shouldn't be. We'll
@@ -235,9 +238,10 @@ LValue CIRGenFunction::buildLValueForField(LValue base,
235238
if (!IsInPreservedAIRegion &&
236239
(!getDebugInfo() || !rec->hasAttr<BPFPreserveAccessIndexAttr>())) {
237240
llvm::StringRef fieldName = field->getName();
241+
unsigned fieldIndex = field->getFieldIndex();
238242
if (CGM.LambdaFieldToName.count(field))
239243
fieldName = CGM.LambdaFieldToName[field];
240-
addr = buildAddrOfFieldStorage(*this, addr, field, fieldName);
244+
addr = buildAddrOfFieldStorage(*this, addr, field, fieldName, fieldIndex);
241245
} else
242246
// Remember the original struct field index
243247
addr = buildPreserveStructAccess(*this, base, addr, field);
@@ -281,14 +285,15 @@ LValue CIRGenFunction::buildLValueForField(LValue base,
281285
}
282286

283287
LValue CIRGenFunction::buildLValueForFieldInitialization(
284-
LValue Base, const clang::FieldDecl *Field, llvm::StringRef FieldName) {
288+
LValue Base, const clang::FieldDecl *Field, llvm::StringRef FieldName,
289+
unsigned FieldIndex) {
285290
QualType FieldType = Field->getType();
286291

287292
if (!FieldType->isReferenceType())
288293
return buildLValueForField(Base, Field);
289294

290-
Address V =
291-
buildAddrOfFieldStorage(*this, Base.getAddress(), Field, FieldName);
295+
Address V = buildAddrOfFieldStorage(*this, Base.getAddress(), Field,
296+
FieldName, FieldIndex);
292297

293298
// Make sure that the address is pointing to the right type.
294299
auto memTy = getTypes().convertTypeForMem(FieldType);
@@ -577,9 +582,9 @@ LValue CIRGenFunction::buildDeclRefLValue(const DeclRefExpr *E) {
577582

578583
if (const auto *VD = dyn_cast<VarDecl>(ND)) {
579584
// Global Named registers access via intrinsics only
580-
if (VD->getStorageClass() == SC_Register &&
581-
VD->hasAttr<AsmLabelAttr>() && !VD->isLocalVarDecl())
582-
llvm_unreachable("NYI");
585+
if (VD->getStorageClass() == SC_Register && VD->hasAttr<AsmLabelAttr>() &&
586+
!VD->isLocalVarDecl())
587+
llvm_unreachable("NYI");
583588

584589
assert(E->isNonOdrUse() != NOUR_Constant && "not implemented");
585590

@@ -1172,12 +1177,11 @@ LValue CIRGenFunction::buildArraySubscriptExpr(const ArraySubscriptExpr *E,
11721177
bool Accessed) {
11731178
// The index must always be an integer, which is not an aggregate. Emit it
11741179
// in lexical order (this complexity is, sadly, required by C++17).
1175-
// llvm::Value *IdxPre =
1176-
// (E->getLHS() == E->getIdx()) ? EmitScalarExpr(E->getIdx()) : nullptr;
1177-
assert(E->getLHS() != E->getIdx() && "not implemented");
1180+
mlir::Value IdxPre =
1181+
(E->getLHS() == E->getIdx()) ? buildScalarExpr(E->getIdx()) : nullptr;
11781182
bool SignedIndices = false;
1179-
auto EmitIdxAfterBase = [&](bool Promote) -> mlir::Value {
1180-
mlir::Value Idx;
1183+
auto EmitIdxAfterBase = [&, IdxPre](bool Promote) -> mlir::Value {
1184+
mlir::Value Idx = IdxPre;
11811185
if (E->getLHS() != E->getIdx()) {
11821186
assert(E->getRHS() == E->getIdx() && "index was neither LHS nor RHS");
11831187
Idx = buildScalarExpr(E->getIdx());
@@ -1187,39 +1191,41 @@ LValue CIRGenFunction::buildArraySubscriptExpr(const ArraySubscriptExpr *E,
11871191
bool IdxSigned = IdxTy->isSignedIntegerOrEnumerationType();
11881192
SignedIndices |= IdxSigned;
11891193

1190-
assert(!SanOpts.has(SanitizerKind::ArrayBounds) && "not implemented");
1194+
if (SanOpts.has(SanitizerKind::ArrayBounds))
1195+
llvm_unreachable("array bounds sanitizer is NYI");
11911196

1192-
// TODO: Extend or truncate the index type to 32 or 64-bits.
1193-
// if (Promote && !Idx.getType().isa<::mlir::cir::PointerType>()) {
1194-
// Idx = Builder.CreateIntCast(Idx, IntPtrTy, IdxSigned, "idxprom");
1195-
// }
1197+
// Extend or truncate the index type to 32 or 64-bits.
1198+
auto ptrTy = Idx.getType().dyn_cast<mlir::cir::PointerType>();
1199+
if (Promote && ptrTy && ptrTy.getPointee().isa<mlir::cir::IntType>())
1200+
llvm_unreachable("index type cast is NYI");
11961201

11971202
return Idx;
11981203
};
1204+
IdxPre = nullptr;
11991205

12001206
// If the base is a vector type, then we are forming a vector element
12011207
// with this subscript.
12021208
if (E->getBase()->getType()->isVectorType() &&
12031209
!isa<ExtVectorElementExpr>(E->getBase())) {
1204-
assert(0 && "not implemented");
1210+
llvm_unreachable("vector subscript is NYI");
12051211
}
12061212

12071213
// All the other cases basically behave like simple offsetting.
12081214

12091215
// Handle the extvector case we ignored above.
12101216
if (isa<ExtVectorElementExpr>(E->getBase())) {
1211-
assert(0 && "not implemented");
1217+
llvm_unreachable("extvector subscript is NYI");
12121218
}
12131219

1214-
// TODO: TBAAAccessInfo
1220+
assert(!UnimplementedFeature::tbaa() && "TBAA is NYI");
12151221
LValueBaseInfo EltBaseInfo;
12161222
Address Addr = Address::invalid();
12171223
if (const VariableArrayType *vla =
12181224
getContext().getAsVariableArrayType(E->getType())) {
1219-
assert(0 && "not implemented");
1225+
llvm_unreachable("variable array subscript is NYI");
12201226
} else if (const ObjCObjectType *OIT =
12211227
E->getType()->getAs<ObjCObjectType>()) {
1222-
assert(0 && "not implemented");
1228+
llvm_unreachable("ObjC object type subscript is NYI");
12231229
} else if (const Expr *Array = isSimpleArrayDecayOperand(E->getBase())) {
12241230
// If this is A[i] where A is an array, the frontend will have decayed
12251231
// the base to be a ArrayToPointerDecay implicit cast. While correct, it is
@@ -1230,26 +1236,26 @@ LValue CIRGenFunction::buildArraySubscriptExpr(const ArraySubscriptExpr *E,
12301236
LValue ArrayLV;
12311237
// For simple multidimensional array indexing, set the 'accessed' flag
12321238
// for better bounds-checking of the base expression.
1233-
// if (const auto *ASE = dyn_cast<ArraySubscriptExpr>(Array))
1234-
// ArrayLV = buildArraySubscriptExpr(ASE, /*Accessed*/ true);
1235-
assert(!llvm::isa<ArraySubscriptExpr>(Array) &&
1236-
"multidimensional array indexing not implemented");
1237-
1238-
ArrayLV = buildLValue(Array);
1239+
if (const auto *ASE = dyn_cast<ArraySubscriptExpr>(Array))
1240+
ArrayLV = buildArraySubscriptExpr(ASE, /*Accessed=*/true);
1241+
else
1242+
ArrayLV = buildLValue(Array);
12391243
auto Idx = EmitIdxAfterBase(/*Promote=*/true);
1240-
QualType arrayType = Array->getType();
12411244

12421245
// Propagate the alignment from the array itself to the result.
1246+
QualType arrayType = Array->getType();
12431247
Addr = buildArraySubscriptPtr(
12441248
*this, CGM.getLoc(Array->getBeginLoc()), CGM.getLoc(Array->getEndLoc()),
12451249
ArrayLV.getAddress(), {Idx}, E->getType(),
12461250
!getLangOpts().isSignedOverflowDefined(), SignedIndices,
12471251
CGM.getLoc(E->getExprLoc()), &arrayType, E->getBase());
12481252
EltBaseInfo = ArrayLV.getBaseInfo();
1249-
// TODO: EltTBAAInfo
1253+
// TODO(cir): EltTBAAInfo
1254+
assert(!UnimplementedFeature::tbaa() && "TBAA is NYI");
12501255
} else {
12511256
// The base must be a pointer; emit it with an estimate of its alignment.
12521257
// TODO(cir): EltTBAAInfo
1258+
assert(!UnimplementedFeature::tbaa() && "TBAA is NYI");
12531259
Addr = buildPointerWithAlignment(E->getBase(), &EltBaseInfo);
12541260
auto Idx = EmitIdxAfterBase(/*Promote*/ true);
12551261
QualType ptrType = E->getBase()->getType();
@@ -1260,9 +1266,11 @@ LValue CIRGenFunction::buildArraySubscriptExpr(const ArraySubscriptExpr *E,
12601266
}
12611267

12621268
LValue LV = LValue::makeAddr(Addr, E->getType(), EltBaseInfo);
1269+
12631270
if (getLangOpts().ObjC && getLangOpts().getGC() != LangOptions::NonGC) {
1264-
assert(0 && "not implemented");
1271+
llvm_unreachable("ObjC is NYI");
12651272
}
1273+
12661274
return LV;
12671275
}
12681276

clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,8 @@ void AggExprEmitter::VisitLambdaExpr(LambdaExpr *E) {
445445
}
446446

447447
// Emit initialization
448-
LValue LV =
449-
CGF.buildLValueForFieldInitialization(SlotLV, *CurField, fieldName);
448+
LValue LV = CGF.buildLValueForFieldInitialization(
449+
SlotLV, *CurField, fieldName, CurField->getFieldIndex());
450450
if (CurField->hasCapturedVLAType()) {
451451
llvm_unreachable("NYI");
452452
}
@@ -701,8 +701,8 @@ void AggExprEmitter::VisitCXXParenListOrInitListExpr(
701701
CGF.getTypes().isZeroInitializable(ExprToVisit->getType()))
702702
break;
703703

704-
LValue LV =
705-
CGF.buildLValueForFieldInitialization(DestLV, field, field->getName());
704+
LValue LV = CGF.buildLValueForFieldInitialization(
705+
DestLV, field, field->getName(), field->getFieldIndex());
706706
// We never generate write-barries for initialized fields.
707707
assert(!UnimplementedFeature::setNonGC());
708708

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class CallOp;
4444
namespace {
4545
class ScalarExprEmitter;
4646
class AggExprEmitter;
47-
}
47+
} // namespace
4848

4949
namespace cir {
5050

@@ -1448,7 +1448,8 @@ class CIRGenFunction : public CIRGenTypeCache {
14481448
/// stored in the reference.
14491449
LValue buildLValueForFieldInitialization(LValue Base,
14501450
const clang::FieldDecl *Field,
1451-
llvm::StringRef FieldName);
1451+
llvm::StringRef FieldName,
1452+
unsigned FieldIndex);
14521453

14531454
void buildInitializerForField(clang::FieldDecl *Field, LValue LHS,
14541455
clang::Expr *Init);

clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/IR/Matchers.h"
1717
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Support/LogicalResult.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1920

2021
using namespace mlir;
@@ -156,10 +157,17 @@ template <>
156157
mlir::LogicalResult
157158
SimplifyRetYieldBlocks<ScopeOp>::replaceScopeLikeOp(PatternRewriter &rewriter,
158159
ScopeOp scopeOp) const {
159-
auto regionChanged = mlir::failure();
160+
// Scope region empty: just remove scope.
161+
if (scopeOp.getRegion().empty()) {
162+
rewriter.eraseOp(scopeOp);
163+
return mlir::success();
164+
}
165+
166+
// Scope region non-empty: clean it up.
160167
if (checkAndRewriteRegion(scopeOp.getRegion(), rewriter).succeeded())
161-
regionChanged = mlir::success();
162-
return regionChanged;
168+
return mlir::success();
169+
170+
return mlir::failure();
163171
}
164172

165173
template <>

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,7 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
11041104
if (ty.isUnsigned())
11051105
rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, llvmTy, lhs, rhs);
11061106
else
1107-
llvm_unreachable("signed integer division binop lowering NYI");
1107+
rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, llvmTy, lhs, rhs);
11081108
} else
11091109
rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, llvmTy, lhs, rhs);
11101110
break;
@@ -1113,7 +1113,7 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
11131113
if (ty.isUnsigned())
11141114
rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, llvmTy, lhs, rhs);
11151115
else
1116-
llvm_unreachable("signed integer remainder binop lowering NYI");
1116+
rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, llvmTy, lhs, rhs);
11171117
} else
11181118
rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, llvmTy, lhs, rhs);
11191119
break;
@@ -1134,7 +1134,7 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
11341134
if (ty.isUnsigned())
11351135
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, lhs, rhs);
11361136
else
1137-
llvm_unreachable("signed integer shift binop lowering NYI");
1137+
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, lhs, rhs);
11381138
break;
11391139
}
11401140
}
@@ -1286,18 +1286,37 @@ class CIRBrOpLowering : public mlir::OpConversionPattern<mlir::cir::BrOp> {
12861286
}
12871287
};
12881288

1289+
class CIRStructElementAddrOpLowering
1290+
: public mlir::OpConversionPattern<mlir::cir::StructElementAddr> {
1291+
public:
1292+
using mlir::OpConversionPattern<
1293+
mlir::cir::StructElementAddr>::OpConversionPattern;
1294+
1295+
mlir::LogicalResult
1296+
matchAndRewrite(mlir::cir::StructElementAddr op, OpAdaptor adaptor,
1297+
mlir::ConversionPatternRewriter &rewriter) const override {
1298+
auto llResTy = getTypeConverter()->convertType(op.getType());
1299+
// Since the base address is a pointer to structs, the first offset is
1300+
// always zero. The second offset tell us which member it will access.
1301+
llvm::SmallVector<mlir::LLVM::GEPArg> offset{0, op.getIndex()};
1302+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1303+
op, llResTy, adaptor.getStructAddr(), offset);
1304+
return mlir::success();
1305+
}
1306+
};
1307+
12891308
void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
12901309
mlir::TypeConverter &converter) {
12911310
patterns.add<CIRReturnLowering>(patterns.getContext());
1292-
patterns.add<CIRCmpOpLowering, CIRLoopOpLowering, CIRBrCondOpLowering,
1293-
CIRPtrStrideOpLowering, CIRCallLowering, CIRUnaryOpLowering,
1294-
CIRBinOpLowering, CIRLoadLowering, CIRConstantLowering,
1295-
CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering,
1296-
CIRScopeOpLowering, CIRCastOpLowering, CIRIfLowering,
1297-
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRVAStartLowering,
1298-
CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
1299-
CIRBrOpLowering, CIRTernaryOpLowering>(converter,
1300-
patterns.getContext());
1311+
patterns.add<
1312+
CIRCmpOpLowering, CIRLoopOpLowering, CIRBrCondOpLowering,
1313+
CIRPtrStrideOpLowering, CIRCallLowering, CIRUnaryOpLowering,
1314+
CIRBinOpLowering, CIRLoadLowering, CIRConstantLowering, CIRStoreLowering,
1315+
CIRAllocaLowering, CIRFuncLowering, CIRScopeOpLowering, CIRCastOpLowering,
1316+
CIRIfLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1317+
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
1318+
CIRBrOpLowering, CIRTernaryOpLowering, CIRStructElementAddrOpLowering>(
1319+
converter, patterns.getContext());
13011320
}
13021321

13031322
namespace {

0 commit comments

Comments
 (0)