Skip to content

Commit fcd1db0

Browse files
gitoleglanza
authored andcommitted
[CIR][CIRGen][Bugfix] Fixes switch-case sub statements (#232)
This PR fixes CIR generation for the `switch-case` cases like the following: ``` case 'a': default: ... ``` or ``` default: case 'a': ... ``` i.e. when the `default` clause is sub-statement of the `case` one and vice versa.
1 parent a66193f commit fcd1db0

File tree

3 files changed

+209
-67
lines changed

3 files changed

+209
-67
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ class CIRGenFunction : public CIRGenTypeCache {
577577
const CIRGenFunctionInfo *CurFnInfo;
578578
clang::QualType FnRetTy;
579579

580-
/// This is the current function or global initializer that is generated code for.
580+
/// This is the current function or global initializer that is generated code
581+
/// for.
581582
mlir::Operation *CurFn = nullptr;
582583

583584
/// Save Parameter Decl for coroutine.
@@ -593,7 +594,7 @@ class CIRGenFunction : public CIRGenTypeCache {
593594

594595
CIRGenModule &getCIRGenModule() { return CGM; }
595596

596-
mlir::Block* getCurFunctionEntryBlock() {
597+
mlir::Block *getCurFunctionEntryBlock() {
597598
auto Fn = dyn_cast<mlir::cir::FuncOp>(CurFn);
598599
assert(Fn && "other callables NYI");
599600
return &Fn.getRegion().front();
@@ -1120,13 +1121,26 @@ class CIRGenFunction : public CIRGenTypeCache {
11201121

11211122
mlir::Type getCIRType(const clang::QualType &type);
11221123

1124+
const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
1125+
SmallVector<mlir::Attribute, 4> &caseAttrs);
1126+
1127+
void insertFallthrough(const clang::Stmt &S);
1128+
1129+
template <typename T>
1130+
mlir::LogicalResult
1131+
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
1132+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1133+
mlir::OperationState &os);
1134+
11231135
mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
11241136
mlir::Type condType,
1125-
mlir::cir::CaseAttr &caseEntry);
1137+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1138+
mlir::OperationState &op);
11261139

1127-
mlir::LogicalResult buildDefaultStmt(const clang::DefaultStmt &S,
1128-
mlir::Type condType,
1129-
mlir::cir::CaseAttr &caseEntry);
1140+
mlir::LogicalResult
1141+
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
1142+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1143+
mlir::OperationState &op);
11301144

11311145
mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
11321146
const CIRGenFunctionInfo &FnInfo);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 90 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -557,63 +557,102 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
557557
return mlir::success();
558558
}
559559

560-
mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S,
561-
mlir::Type condType,
562-
CaseAttr &caseEntry) {
563-
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
564-
"case ranges not implemented");
565-
auto res = mlir::success();
566-
560+
const CaseStmt *
561+
CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
562+
SmallVector<mlir::Attribute, 4> &caseAttrs) {
567563
const CaseStmt *caseStmt = &S;
564+
const CaseStmt *lastCase = &S;
568565
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;
566+
569567
// Fold cascading cases whenever possible to simplify codegen a bit.
570-
while (true) {
568+
while (caseStmt) {
569+
lastCase = caseStmt;
571570
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
572571
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal));
573-
if (isa<CaseStmt>(caseStmt->getSubStmt()))
574-
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
575-
else
576-
break;
572+
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
577573
}
578574

579-
auto caseValueList = builder.getArrayAttr(caseEltValueListAttr);
575+
auto *ctxt = builder.getContext();
580576

581-
auto *ctx = builder.getContext();
582-
caseEntry = mlir::cir::CaseAttr::get(
583-
ctx, caseValueList,
584-
CaseOpKindAttr::get(ctx, caseEltValueListAttr.size() > 1
585-
? mlir::cir::CaseOpKind::Anyof
586-
: mlir::cir::CaseOpKind::Equal));
577+
auto caseAttr = mlir::cir::CaseAttr::get(
578+
ctxt, builder.getArrayAttr(caseEltValueListAttr),
579+
CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1
580+
? mlir::cir::CaseOpKind::Anyof
581+
: mlir::cir::CaseOpKind::Equal));
587582

588-
{
589-
mlir::OpBuilder::InsertionGuard guardCase(builder);
590-
res = buildStmt(
591-
caseStmt->getSubStmt(),
592-
/*useCurrentScope=*/!isa<CompoundStmt>(caseStmt->getSubStmt()));
593-
}
583+
caseAttrs.push_back(caseAttr);
594584

595-
// TODO: likelihood
596-
return res;
585+
return lastCase;
597586
}
598587

599-
mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S,
600-
mlir::Type condType,
601-
CaseAttr &caseEntry) {
588+
void CIRGenFunction::insertFallthrough(const clang::Stmt &S) {
589+
builder.create<YieldOp>(
590+
getLoc(S.getBeginLoc()),
591+
mlir::cir::YieldOpKindAttr::get(builder.getContext(),
592+
mlir::cir::YieldOpKind::Fallthrough),
593+
mlir::ValueRange({}));
594+
}
595+
596+
template <typename T>
597+
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
598+
const T *stmt, mlir::Type condType,
599+
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
600+
601+
assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
602+
"only case or default stmt go here");
603+
602604
auto res = mlir::success();
603-
auto *ctx = builder.getContext();
604-
caseEntry = mlir::cir::CaseAttr::get(
605-
ctx, builder.getArrayAttr({}),
606-
CaseOpKindAttr::get(ctx, mlir::cir::CaseOpKind::Default));
607-
{
605+
606+
// Update scope information with the current region we are
607+
// emitting code for. This is useful to allow return blocks to be
608+
// automatically and properly placed during cleanup.
609+
auto *region = os.addRegion();
610+
auto *block = builder.createBlock(region);
611+
builder.setInsertionPointToEnd(block);
612+
currLexScope->updateCurrentSwitchCaseRegion();
613+
614+
auto *sub = stmt->getSubStmt();
615+
616+
if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
617+
insertFallthrough(*stmt);
618+
res =
619+
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
620+
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
621+
insertFallthrough(*stmt);
622+
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
623+
} else {
608624
mlir::OpBuilder::InsertionGuard guardCase(builder);
609-
res = buildStmt(S.getSubStmt(),
610-
/*useCurrentScope=*/!isa<CompoundStmt>(S.getSubStmt()));
625+
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
611626
}
612627

613-
// TODO: likelihood
614628
return res;
615629
}
616630

631+
mlir::LogicalResult
632+
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
633+
SmallVector<mlir::Attribute, 4> &caseAttrs,
634+
mlir::OperationState &os) {
635+
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
636+
"case ranges not implemented");
637+
638+
auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
639+
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
640+
}
641+
642+
mlir::LogicalResult
643+
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
644+
SmallVector<mlir::Attribute, 4> &caseAttrs,
645+
mlir::OperationState &os) {
646+
auto ctxt = builder.getContext();
647+
648+
auto defAttr = mlir::cir::CaseAttr::get(
649+
ctxt, builder.getArrayAttr({}),
650+
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
651+
652+
caseAttrs.push_back(defAttr);
653+
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
654+
}
655+
617656
static mlir::LogicalResult buildLoopCondYield(mlir::OpBuilder &builder,
618657
mlir::Location loc,
619658
mlir::Value cond) {
@@ -958,29 +997,20 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
958997
}
959998

960999
auto *caseStmt = dyn_cast<CaseStmt>(c);
961-
CaseAttr caseAttr;
962-
{
963-
mlir::OpBuilder::InsertionGuard guardCase(builder);
9641000

965-
// Update scope information with the current region we are
966-
// emitting code for. This is useful to allow return blocks to be
967-
// automatically and properly placed during cleanup.
968-
mlir::Region *caseRegion = os.addRegion();
969-
currLexScope->updateCurrentSwitchCaseRegion();
970-
971-
lastCaseBlock = builder.createBlock(caseRegion);
972-
if (caseStmt)
973-
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttr);
974-
else {
975-
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
976-
assert(defaultStmt && "expected default stmt");
977-
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttr);
978-
}
979-
980-
if (res.failed())
981-
break;
1001+
if (caseStmt)
1002+
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
1003+
else {
1004+
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
1005+
assert(defaultStmt && "expected default stmt");
1006+
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
1007+
os);
9821008
}
983-
caseAttrs.push_back(caseAttr);
1009+
1010+
lastCaseBlock = builder.getBlock();
1011+
1012+
if (res.failed())
1013+
break;
9841014
}
9851015

9861016
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
@@ -1057,4 +1087,4 @@ void CIRGenFunction::buildReturnOfRValue(mlir::Location loc, RValue RV,
10571087
llvm_unreachable("NYI");
10581088
}
10591089
buildBranchThroughCleanup(loc, ReturnBlock());
1060-
}
1090+
}

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ void sw1(int a) {
1515
}
1616
}
1717
}
18-
1918
// CHECK: cir.func @_Z3sw1i
2019
// CHECK: cir.switch (%3 : !s32i) [
2120
// CHECK-NEXT: case (equal, 0) {
@@ -160,3 +159,102 @@ void sw7(int a) {
160159
// CHECK-NEXT: case (anyof, [3, 4, 5] : !s32i) {
161160
// CHECK-NEXT: cir.yield break
162161
// CHECK-NEXT: }
162+
163+
void sw8(int a) {
164+
switch (a)
165+
{
166+
case 3:
167+
break;
168+
case 4:
169+
default:
170+
break;
171+
}
172+
}
173+
174+
//CHECK: cir.func @_Z3sw8i
175+
//CHECK: case (equal, 3)
176+
//CHECK-NEXT: cir.yield break
177+
//CHECK-NEXT: },
178+
//CHECK-NEXT: case (equal, 4) {
179+
//CHECK-NEXT: cir.yield fallthrough
180+
//CHECK-NEXT: }
181+
//CHECK-NEXT: case (default) {
182+
//CHECK-NEXT: cir.yield break
183+
//CHECK-NEXT: }
184+
185+
void sw9(int a) {
186+
switch (a)
187+
{
188+
case 3:
189+
break;
190+
default:
191+
case 4:
192+
break;
193+
}
194+
}
195+
196+
//CHECK: cir.func @_Z3sw9i
197+
//CHECK: case (equal, 3) {
198+
//CHECK-NEXT: cir.yield break
199+
//CHECK-NEXT: }
200+
//CHECK-NEXT: case (default) {
201+
//CHECK-NEXT: cir.yield fallthrough
202+
//CHECK-NEXT: }
203+
//CHECK: case (equal, 4)
204+
//CHECK-NEXT: cir.yield break
205+
//CHECK-NEXT: }
206+
207+
void sw10(int a) {
208+
switch (a)
209+
{
210+
case 3:
211+
break;
212+
case 4:
213+
default:
214+
case 5:
215+
break;
216+
}
217+
}
218+
219+
//CHECK: cir.func @_Z4sw10i
220+
//CHECK: case (equal, 3)
221+
//CHECK-NEXT: cir.yield break
222+
//CHECK-NEXT: },
223+
//CHECK-NEXT: case (equal, 4) {
224+
//CHECK-NEXT: cir.yield fallthrough
225+
//CHECK-NEXT: }
226+
//CHECK-NEXT: case (default) {
227+
//CHECK-NEXT: cir.yield fallthrough
228+
//CHECK-NEXT: }
229+
//CHECK-NEXT: case (equal, 5) {
230+
//CHECK-NEXT: cir.yield break
231+
//CHECK-NEXT: }
232+
233+
void sw11(int a) {
234+
switch (a)
235+
{
236+
case 3:
237+
break;
238+
case 4:
239+
case 5:
240+
default:
241+
case 6:
242+
case 7:
243+
break;
244+
}
245+
}
246+
247+
//CHECK: cir.func @_Z4sw11i
248+
//CHECK: case (equal, 3)
249+
//CHECK-NEXT: cir.yield break
250+
//CHECK-NEXT: },
251+
//CHECK-NEXT: case (anyof, [4, 5] : !s32i) {
252+
//CHECK-NEXT: cir.yield fallthrough
253+
//CHECK-NEXT: }
254+
//CHECK-NEXT: case (default) {
255+
//CHECK-NEXT: cir.yield fallthrough
256+
//CHECK-NEXT: }
257+
//CHECK-NEXT: case (anyof, [6, 7] : !s32i) {
258+
//CHECK-NEXT: cir.yield break
259+
//CHECK-NEXT: }
260+

0 commit comments

Comments
 (0)