Skip to content

Commit 34bbfc2

Browse files
gitoleglanza
authored andcommitted
[CIR][CIRGen][Bugfix] Fixes switch-case sub statements (#232)
This PR fixes CIR generation for the `switch-case` cases like the following: ``` case 'a': default: ... ``` or ``` default: case 'a': ... ``` i.e. when the `default` clause is sub-statement of the `case` one and vice versa.
1 parent f089c2d commit 34bbfc2

File tree

3 files changed

+209
-66
lines changed

3 files changed

+209
-66
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ class CIRGenFunction : public CIRGenTypeCache {
577577
const CIRGenFunctionInfo *CurFnInfo;
578578
clang::QualType FnRetTy;
579579

580-
/// This is the current function or global initializer that is generated code for.
580+
/// This is the current function or global initializer that is generated code
581+
/// for.
581582
mlir::Operation *CurFn = nullptr;
582583

583584
/// Save Parameter Decl for coroutine.
@@ -593,7 +594,7 @@ class CIRGenFunction : public CIRGenTypeCache {
593594

594595
CIRGenModule &getCIRGenModule() { return CGM; }
595596

596-
mlir::Block* getCurFunctionEntryBlock() {
597+
mlir::Block *getCurFunctionEntryBlock() {
597598
auto Fn = dyn_cast<mlir::cir::FuncOp>(CurFn);
598599
assert(Fn && "other callables NYI");
599600
return &Fn.getRegion().front();
@@ -1120,13 +1121,26 @@ class CIRGenFunction : public CIRGenTypeCache {
11201121

11211122
mlir::Type getCIRType(const clang::QualType &type);
11221123

1124+
const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
1125+
SmallVector<mlir::Attribute, 4> &caseAttrs);
1126+
1127+
void insertFallthrough(const clang::Stmt &S);
1128+
1129+
template <typename T>
1130+
mlir::LogicalResult
1131+
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
1132+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1133+
mlir::OperationState &os);
1134+
11231135
mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
11241136
mlir::Type condType,
1125-
mlir::cir::CaseAttr &caseEntry);
1137+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1138+
mlir::OperationState &op);
11261139

1127-
mlir::LogicalResult buildDefaultStmt(const clang::DefaultStmt &S,
1128-
mlir::Type condType,
1129-
mlir::cir::CaseAttr &caseEntry);
1140+
mlir::LogicalResult
1141+
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
1142+
SmallVector<mlir::Attribute, 4> &caseAttrs,
1143+
mlir::OperationState &op);
11301144

11311145
mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
11321146
const CIRGenFunctionInfo &FnInfo);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 90 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -554,62 +554,102 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
554554
return mlir::success();
555555
}
556556

557-
mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S,
558-
mlir::Type condType,
559-
CaseAttr &caseEntry) {
560-
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
561-
"case ranges not implemented");
562-
auto res = mlir::success();
563-
557+
const CaseStmt *
558+
CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
559+
SmallVector<mlir::Attribute, 4> &caseAttrs) {
564560
const CaseStmt *caseStmt = &S;
561+
const CaseStmt *lastCase = &S;
565562
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;
563+
566564
// Fold cascading cases whenever possible to simplify codegen a bit.
567-
while (true) {
565+
while (caseStmt) {
566+
lastCase = caseStmt;
568567
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
569568
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal));
570-
if (isa<CaseStmt>(caseStmt->getSubStmt()))
571-
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
572-
else
573-
break;
569+
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
574570
}
575571

576-
auto caseValueList = builder.getArrayAttr(caseEltValueListAttr);
572+
auto *ctxt = builder.getContext();
577573

578-
auto *ctx = builder.getContext();
579-
caseEntry = mlir::cir::CaseAttr::get(
580-
ctx, caseValueList,
581-
CaseOpKindAttr::get(ctx, caseEltValueListAttr.size() > 1
582-
? mlir::cir::CaseOpKind::Anyof
583-
: mlir::cir::CaseOpKind::Equal));
584-
{
585-
mlir::OpBuilder::InsertionGuard guardCase(builder);
586-
res = buildStmt(
587-
caseStmt->getSubStmt(),
588-
/*useCurrentScope=*/!isa<CompoundStmt>(caseStmt->getSubStmt()));
589-
}
574+
auto caseAttr = mlir::cir::CaseAttr::get(
575+
ctxt, builder.getArrayAttr(caseEltValueListAttr),
576+
CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1
577+
? mlir::cir::CaseOpKind::Anyof
578+
: mlir::cir::CaseOpKind::Equal));
590579

591-
// TODO: likelihood
592-
return res;
580+
caseAttrs.push_back(caseAttr);
581+
582+
return lastCase;
593583
}
594584

595-
mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S,
596-
mlir::Type condType,
597-
CaseAttr &caseEntry) {
585+
void CIRGenFunction::insertFallthrough(const clang::Stmt &S) {
586+
builder.create<YieldOp>(
587+
getLoc(S.getBeginLoc()),
588+
mlir::cir::YieldOpKindAttr::get(builder.getContext(),
589+
mlir::cir::YieldOpKind::Fallthrough),
590+
mlir::ValueRange({}));
591+
}
592+
593+
template <typename T>
594+
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
595+
const T *stmt, mlir::Type condType,
596+
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
597+
598+
assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
599+
"only case or default stmt go here");
600+
598601
auto res = mlir::success();
599-
auto *ctx = builder.getContext();
600-
caseEntry = mlir::cir::CaseAttr::get(
601-
ctx, builder.getArrayAttr({}),
602-
CaseOpKindAttr::get(ctx, mlir::cir::CaseOpKind::Default));
603-
{
602+
603+
// Update scope information with the current region we are
604+
// emitting code for. This is useful to allow return blocks to be
605+
// automatically and properly placed during cleanup.
606+
auto *region = os.addRegion();
607+
auto *block = builder.createBlock(region);
608+
builder.setInsertionPointToEnd(block);
609+
currLexScope->updateCurrentSwitchCaseRegion();
610+
611+
auto *sub = stmt->getSubStmt();
612+
613+
if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
614+
insertFallthrough(*stmt);
615+
res =
616+
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
617+
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
618+
insertFallthrough(*stmt);
619+
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
620+
} else {
604621
mlir::OpBuilder::InsertionGuard guardCase(builder);
605-
res = buildStmt(S.getSubStmt(),
606-
/*useCurrentScope=*/!isa<CompoundStmt>(S.getSubStmt()));
622+
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
607623
}
608624

609-
// TODO: likelihood
610625
return res;
611626
}
612627

628+
mlir::LogicalResult
629+
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
630+
SmallVector<mlir::Attribute, 4> &caseAttrs,
631+
mlir::OperationState &os) {
632+
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
633+
"case ranges not implemented");
634+
635+
auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
636+
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
637+
}
638+
639+
mlir::LogicalResult
640+
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
641+
SmallVector<mlir::Attribute, 4> &caseAttrs,
642+
mlir::OperationState &os) {
643+
auto ctxt = builder.getContext();
644+
645+
auto defAttr = mlir::cir::CaseAttr::get(
646+
ctxt, builder.getArrayAttr({}),
647+
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
648+
649+
caseAttrs.push_back(defAttr);
650+
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
651+
}
652+
613653
static mlir::LogicalResult buildLoopCondYield(mlir::OpBuilder &builder,
614654
mlir::Location loc,
615655
mlir::Value cond) {
@@ -954,29 +994,20 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
954994
}
955995

956996
auto *caseStmt = dyn_cast<CaseStmt>(c);
957-
CaseAttr caseAttr;
958-
{
959-
mlir::OpBuilder::InsertionGuard guardCase(builder);
960-
961-
// Update scope information with the current region we are
962-
// emitting code for. This is useful to allow return blocks to be
963-
// automatically and properly placed during cleanup.
964-
mlir::Region *caseRegion = os.addRegion();
965-
currLexScope->updateCurrentSwitchCaseRegion();
966-
967-
lastCaseBlock = builder.createBlock(caseRegion);
968-
if (caseStmt)
969-
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttr);
970-
else {
971-
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
972-
assert(defaultStmt && "expected default stmt");
973-
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttr);
974-
}
975997

976-
if (res.failed())
977-
break;
998+
if (caseStmt)
999+
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
1000+
else {
1001+
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
1002+
assert(defaultStmt && "expected default stmt");
1003+
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
1004+
os);
9781005
}
979-
caseAttrs.push_back(caseAttr);
1006+
1007+
lastCaseBlock = builder.getBlock();
1008+
1009+
if (res.failed())
1010+
break;
9801011
}
9811012

9821013
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ void sw1(int a) {
1515
}
1616
}
1717
}
18-
1918
// CHECK: cir.func @_Z3sw1i
2019
// CHECK: cir.switch (%3 : !s32i) [
2120
// CHECK-NEXT: case (equal, 0) {
@@ -160,3 +159,102 @@ void sw7(int a) {
160159
// CHECK-NEXT: case (anyof, [3, 4, 5] : !s32i) {
161160
// CHECK-NEXT: cir.yield break
162161
// CHECK-NEXT: }
162+
163+
void sw8(int a) {
164+
switch (a)
165+
{
166+
case 3:
167+
break;
168+
case 4:
169+
default:
170+
break;
171+
}
172+
}
173+
174+
//CHECK: cir.func @_Z3sw8i
175+
//CHECK: case (equal, 3)
176+
//CHECK-NEXT: cir.yield break
177+
//CHECK-NEXT: },
178+
//CHECK-NEXT: case (equal, 4) {
179+
//CHECK-NEXT: cir.yield fallthrough
180+
//CHECK-NEXT: }
181+
//CHECK-NEXT: case (default) {
182+
//CHECK-NEXT: cir.yield break
183+
//CHECK-NEXT: }
184+
185+
void sw9(int a) {
186+
switch (a)
187+
{
188+
case 3:
189+
break;
190+
default:
191+
case 4:
192+
break;
193+
}
194+
}
195+
196+
//CHECK: cir.func @_Z3sw9i
197+
//CHECK: case (equal, 3) {
198+
//CHECK-NEXT: cir.yield break
199+
//CHECK-NEXT: }
200+
//CHECK-NEXT: case (default) {
201+
//CHECK-NEXT: cir.yield fallthrough
202+
//CHECK-NEXT: }
203+
//CHECK: case (equal, 4)
204+
//CHECK-NEXT: cir.yield break
205+
//CHECK-NEXT: }
206+
207+
void sw10(int a) {
208+
switch (a)
209+
{
210+
case 3:
211+
break;
212+
case 4:
213+
default:
214+
case 5:
215+
break;
216+
}
217+
}
218+
219+
//CHECK: cir.func @_Z4sw10i
220+
//CHECK: case (equal, 3)
221+
//CHECK-NEXT: cir.yield break
222+
//CHECK-NEXT: },
223+
//CHECK-NEXT: case (equal, 4) {
224+
//CHECK-NEXT: cir.yield fallthrough
225+
//CHECK-NEXT: }
226+
//CHECK-NEXT: case (default) {
227+
//CHECK-NEXT: cir.yield fallthrough
228+
//CHECK-NEXT: }
229+
//CHECK-NEXT: case (equal, 5) {
230+
//CHECK-NEXT: cir.yield break
231+
//CHECK-NEXT: }
232+
233+
void sw11(int a) {
234+
switch (a)
235+
{
236+
case 3:
237+
break;
238+
case 4:
239+
case 5:
240+
default:
241+
case 6:
242+
case 7:
243+
break;
244+
}
245+
}
246+
247+
//CHECK: cir.func @_Z4sw11i
248+
//CHECK: case (equal, 3)
249+
//CHECK-NEXT: cir.yield break
250+
//CHECK-NEXT: },
251+
//CHECK-NEXT: case (anyof, [4, 5] : !s32i) {
252+
//CHECK-NEXT: cir.yield fallthrough
253+
//CHECK-NEXT: }
254+
//CHECK-NEXT: case (default) {
255+
//CHECK-NEXT: cir.yield fallthrough
256+
//CHECK-NEXT: }
257+
//CHECK-NEXT: case (anyof, [6, 7] : !s32i) {
258+
//CHECK-NEXT: cir.yield break
259+
//CHECK-NEXT: }
260+

0 commit comments

Comments
 (0)