Skip to content

Commit 8e7ed2d

Browse files
gitoleglanza
authored andcommitted
[CIR][CIRGen][Bugfix] Fixes switch-case sub statements (#232)
This PR fixes CIR generation for the `switch-case` cases like the following: ``` case 'a': default: ... ``` or ``` default: case 'a': ... ``` i.e. when the `default` clause is sub-statement of the `case` one and vice versa.
1 parent c4b0daf commit 8e7ed2d

File tree

3 files changed

+209
-67
lines changed

3 files changed

+209
-67
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ class CIRGenFunction : public CIRGenTypeCache {
577577
const CIRGenFunctionInfo *CurFnInfo;
578578
clang::QualType FnRetTy;
579579

580-
/// This is the current function or global initializer that is generated code for.
580+
/// This is the current function or global initializer that is generated code
581+
/// for.
581582
mlir::Operation *CurFn = nullptr;
582583

583584
/// Save Parameter Decl for coroutine.
@@ -593,7 +594,7 @@ class CIRGenFunction : public CIRGenTypeCache {
593594

594595
CIRGenModule &getCIRGenModule() { return CGM; }
595596

596-
mlir::Block* getCurFunctionEntryBlock() {
597+
mlir::Block *getCurFunctionEntryBlock() {
597598
auto Fn = dyn_cast<mlir::cir::FuncOp>(CurFn);
598599
assert(Fn && "other callables NYI");
599600
return &Fn.getRegion().front();
@@ -1120,13 +1121,26 @@ class CIRGenFunction : public CIRGenTypeCache {
11201121

11211122
mlir::Type getCIRType(const clang::QualType &type);
11221123

1124+
const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
1125+
SmallVector<mlir::Attribute, 4> &caseAttrs);
1126+
1127+
void insertFallthrough(const clang::Stmt &S);
1128+
1129+
template <typename T>
1130+
mlir::LogicalResult
1131+
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
1132+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1133+
mlir::OperationState &os);
1134+
11231135
mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
11241136
mlir::Type condType,
1125-
mlir::cir::CaseAttr &caseEntry);
1137+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1138+
mlir::OperationState &op);
11261139

1127-
mlir::LogicalResult buildDefaultStmt(const clang::DefaultStmt &S,
1128-
mlir::Type condType,
1129-
mlir::cir::CaseAttr &caseEntry);
1140+
mlir::LogicalResult
1141+
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
1142+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1143+
mlir::OperationState &op);
11301144

11311145
mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
11321146
const CIRGenFunctionInfo &FnInfo);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 90 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -555,63 +555,102 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
555555
return mlir::success();
556556
}
557557

558-
mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S,
559-
mlir::Type condType,
560-
CaseAttr &caseEntry) {
561-
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
562-
"case ranges not implemented");
563-
auto res = mlir::success();
564-
558+
const CaseStmt *
559+
CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
560+
SmallVector<mlir::Attribute, 4> &caseAttrs) {
565561
const CaseStmt *caseStmt = &S;
562+
const CaseStmt *lastCase = &S;
566563
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;
564+
567565
// Fold cascading cases whenever possible to simplify codegen a bit.
568-
while (true) {
566+
while (caseStmt) {
567+
lastCase = caseStmt;
569568
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
570569
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal));
571-
if (isa<CaseStmt>(caseStmt->getSubStmt()))
572-
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
573-
else
574-
break;
570+
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
575571
}
576572

577-
auto caseValueList = builder.getArrayAttr(caseEltValueListAttr);
573+
auto *ctxt = builder.getContext();
578574

579-
auto *ctx = builder.getContext();
580-
caseEntry = mlir::cir::CaseAttr::get(
581-
ctx, caseValueList,
582-
CaseOpKindAttr::get(ctx, caseEltValueListAttr.size() > 1
583-
? mlir::cir::CaseOpKind::Anyof
584-
: mlir::cir::CaseOpKind::Equal));
575+
auto caseAttr = mlir::cir::CaseAttr::get(
576+
ctxt, builder.getArrayAttr(caseEltValueListAttr),
577+
CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1
578+
? mlir::cir::CaseOpKind::Anyof
579+
: mlir::cir::CaseOpKind::Equal));
585580

586-
{
587-
mlir::OpBuilder::InsertionGuard guardCase(builder);
588-
res = buildStmt(
589-
caseStmt->getSubStmt(),
590-
/*useCurrentScope=*/!isa<CompoundStmt>(caseStmt->getSubStmt()));
591-
}
581+
caseAttrs.push_back(caseAttr);
592582

593-
// TODO: likelihood
594-
return res;
583+
return lastCase;
595584
}
596585

597-
mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S,
598-
mlir::Type condType,
599-
CaseAttr &caseEntry) {
586+
void CIRGenFunction::insertFallthrough(const clang::Stmt &S) {
587+
builder.create<YieldOp>(
588+
getLoc(S.getBeginLoc()),
589+
mlir::cir::YieldOpKindAttr::get(builder.getContext(),
590+
mlir::cir::YieldOpKind::Fallthrough),
591+
mlir::ValueRange({}));
592+
}
593+
594+
template <typename T>
595+
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
596+
const T *stmt, mlir::Type condType,
597+
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
598+
599+
assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
600+
"only case or default stmt go here");
601+
600602
auto res = mlir::success();
601-
auto *ctx = builder.getContext();
602-
caseEntry = mlir::cir::CaseAttr::get(
603-
ctx, builder.getArrayAttr({}),
604-
CaseOpKindAttr::get(ctx, mlir::cir::CaseOpKind::Default));
605-
{
603+
604+
// Update scope information with the current region we are
605+
// emitting code for. This is useful to allow return blocks to be
606+
// automatically and properly placed during cleanup.
607+
auto *region = os.addRegion();
608+
auto *block = builder.createBlock(region);
609+
builder.setInsertionPointToEnd(block);
610+
currLexScope->updateCurrentSwitchCaseRegion();
611+
612+
auto *sub = stmt->getSubStmt();
613+
614+
if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
615+
insertFallthrough(*stmt);
616+
res =
617+
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
618+
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
619+
insertFallthrough(*stmt);
620+
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
621+
} else {
606622
mlir::OpBuilder::InsertionGuard guardCase(builder);
607-
res = buildStmt(S.getSubStmt(),
608-
/*useCurrentScope=*/!isa<CompoundStmt>(S.getSubStmt()));
623+
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
609624
}
610625

611-
// TODO: likelihood
612626
return res;
613627
}
614628

629+
mlir::LogicalResult
630+
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
631+
SmallVector<mlir::Attribute, 4> &caseAttrs,
632+
mlir::OperationState &os) {
633+
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
634+
"case ranges not implemented");
635+
636+
auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
637+
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
638+
}
639+
640+
mlir::LogicalResult
641+
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
642+
SmallVector<mlir::Attribute, 4> &caseAttrs,
643+
mlir::OperationState &os) {
644+
auto ctxt = builder.getContext();
645+
646+
auto defAttr = mlir::cir::CaseAttr::get(
647+
ctxt, builder.getArrayAttr({}),
648+
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
649+
650+
caseAttrs.push_back(defAttr);
651+
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
652+
}
653+
615654
static mlir::LogicalResult buildLoopCondYield(mlir::OpBuilder &builder,
616655
mlir::Location loc,
617656
mlir::Value cond) {
@@ -956,29 +995,20 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
956995
}
957996

958997
auto *caseStmt = dyn_cast<CaseStmt>(c);
959-
CaseAttr caseAttr;
960-
{
961-
mlir::OpBuilder::InsertionGuard guardCase(builder);
962998

963-
// Update scope information with the current region we are
964-
// emitting code for. This is useful to allow return blocks to be
965-
// automatically and properly placed during cleanup.
966-
mlir::Region *caseRegion = os.addRegion();
967-
currLexScope->updateCurrentSwitchCaseRegion();
968-
969-
lastCaseBlock = builder.createBlock(caseRegion);
970-
if (caseStmt)
971-
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttr);
972-
else {
973-
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
974-
assert(defaultStmt && "expected default stmt");
975-
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttr);
976-
}
977-
978-
if (res.failed())
979-
break;
999+
if (caseStmt)
1000+
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
1001+
else {
1002+
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
1003+
assert(defaultStmt && "expected default stmt");
1004+
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
1005+
os);
9801006
}
981-
caseAttrs.push_back(caseAttr);
1007+
1008+
lastCaseBlock = builder.getBlock();
1009+
1010+
if (res.failed())
1011+
break;
9821012
}
9831013

9841014
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
@@ -1055,4 +1085,4 @@ void CIRGenFunction::buildReturnOfRValue(mlir::Location loc, RValue RV,
10551085
llvm_unreachable("NYI");
10561086
}
10571087
buildBranchThroughCleanup(loc, ReturnBlock());
1058-
}
1088+
}

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ void sw1(int a) {
1515
}
1616
}
1717
}
18-
1918
// CHECK: cir.func @_Z3sw1i
2019
// CHECK: cir.switch (%3 : !s32i) [
2120
// CHECK-NEXT: case (equal, 0) {
@@ -160,3 +159,102 @@ void sw7(int a) {
160159
// CHECK-NEXT: case (anyof, [3, 4, 5] : !s32i) {
161160
// CHECK-NEXT: cir.yield break
162161
// CHECK-NEXT: }
162+
163+
void sw8(int a) {
164+
switch (a)
165+
{
166+
case 3:
167+
break;
168+
case 4:
169+
default:
170+
break;
171+
}
172+
}
173+
174+
//CHECK: cir.func @_Z3sw8i
175+
//CHECK: case (equal, 3)
176+
//CHECK-NEXT: cir.yield break
177+
//CHECK-NEXT: },
178+
//CHECK-NEXT: case (equal, 4) {
179+
//CHECK-NEXT: cir.yield fallthrough
180+
//CHECK-NEXT: }
181+
//CHECK-NEXT: case (default) {
182+
//CHECK-NEXT: cir.yield break
183+
//CHECK-NEXT: }
184+
185+
void sw9(int a) {
186+
switch (a)
187+
{
188+
case 3:
189+
break;
190+
default:
191+
case 4:
192+
break;
193+
}
194+
}
195+
196+
//CHECK: cir.func @_Z3sw9i
197+
//CHECK: case (equal, 3) {
198+
//CHECK-NEXT: cir.yield break
199+
//CHECK-NEXT: }
200+
//CHECK-NEXT: case (default) {
201+
//CHECK-NEXT: cir.yield fallthrough
202+
//CHECK-NEXT: }
203+
//CHECK: case (equal, 4)
204+
//CHECK-NEXT: cir.yield break
205+
//CHECK-NEXT: }
206+
207+
void sw10(int a) {
208+
switch (a)
209+
{
210+
case 3:
211+
break;
212+
case 4:
213+
default:
214+
case 5:
215+
break;
216+
}
217+
}
218+
219+
//CHECK: cir.func @_Z4sw10i
220+
//CHECK: case (equal, 3)
221+
//CHECK-NEXT: cir.yield break
222+
//CHECK-NEXT: },
223+
//CHECK-NEXT: case (equal, 4) {
224+
//CHECK-NEXT: cir.yield fallthrough
225+
//CHECK-NEXT: }
226+
//CHECK-NEXT: case (default) {
227+
//CHECK-NEXT: cir.yield fallthrough
228+
//CHECK-NEXT: }
229+
//CHECK-NEXT: case (equal, 5) {
230+
//CHECK-NEXT: cir.yield break
231+
//CHECK-NEXT: }
232+
233+
void sw11(int a) {
234+
switch (a)
235+
{
236+
case 3:
237+
break;
238+
case 4:
239+
case 5:
240+
default:
241+
case 6:
242+
case 7:
243+
break;
244+
}
245+
}
246+
247+
//CHECK: cir.func @_Z4sw11i
248+
//CHECK: case (equal, 3)
249+
//CHECK-NEXT: cir.yield break
250+
//CHECK-NEXT: },
251+
//CHECK-NEXT: case (anyof, [4, 5] : !s32i) {
252+
//CHECK-NEXT: cir.yield fallthrough
253+
//CHECK-NEXT: }
254+
//CHECK-NEXT: case (default) {
255+
//CHECK-NEXT: cir.yield fallthrough
256+
//CHECK-NEXT: }
257+
//CHECK-NEXT: case (anyof, [6, 7] : !s32i) {
258+
//CHECK-NEXT: cir.yield break
259+
//CHECK-NEXT: }
260+

0 commit comments

Comments
 (0)