@@ -555,63 +555,102 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
555555 return mlir::success ();
556556}
557557
558- mlir::LogicalResult CIRGenFunction::buildCaseStmt (const CaseStmt &S,
559- mlir::Type condType,
560- CaseAttr &caseEntry) {
561- assert ((!S.getRHS () || !S.caseStmtIsGNURange ()) &&
562- " case ranges not implemented" );
563- auto res = mlir::success ();
564-
558+ const CaseStmt *
559+ CIRGenFunction::foldCaseStmt (const clang::CaseStmt &S, mlir::Type condType,
560+ SmallVector<mlir::Attribute, 4 > &caseAttrs) {
565561 const CaseStmt *caseStmt = &S;
562+ const CaseStmt *lastCase = &S;
566563 SmallVector<mlir::Attribute, 4 > caseEltValueListAttr;
564+
567565 // Fold cascading cases whenever possible to simplify codegen a bit.
568- while (true ) {
566+ while (caseStmt) {
567+ lastCase = caseStmt;
569568 auto intVal = caseStmt->getLHS ()->EvaluateKnownConstInt (getContext ());
570569 caseEltValueListAttr.push_back (mlir::cir::IntAttr::get (condType, intVal));
571- if (isa<CaseStmt>(caseStmt->getSubStmt ()))
572- caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt ());
573- else
574- break ;
570+ caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt ());
575571 }
576572
577- auto caseValueList = builder.getArrayAttr (caseEltValueListAttr );
573+ auto *ctxt = builder.getContext ( );
578574
579- auto *ctx = builder.getContext ();
580- caseEntry = mlir::cir::CaseAttr::get (
581- ctx, caseValueList,
582- CaseOpKindAttr::get (ctx, caseEltValueListAttr.size () > 1
583- ? mlir::cir::CaseOpKind::Anyof
584- : mlir::cir::CaseOpKind::Equal));
575+ auto caseAttr = mlir::cir::CaseAttr::get (
576+ ctxt, builder.getArrayAttr (caseEltValueListAttr),
577+ CaseOpKindAttr::get (ctxt, caseEltValueListAttr.size () > 1
578+ ? mlir::cir::CaseOpKind::Anyof
579+ : mlir::cir::CaseOpKind::Equal));
585580
586- {
587- mlir::OpBuilder::InsertionGuard guardCase (builder);
588- res = buildStmt (
589- caseStmt->getSubStmt (),
590- /* useCurrentScope=*/ !isa<CompoundStmt>(caseStmt->getSubStmt ()));
591- }
581+ caseAttrs.push_back (caseAttr);
592582
593- // TODO: likelihood
594- return res;
583+ return lastCase;
595584}
596585
597- mlir::LogicalResult CIRGenFunction::buildDefaultStmt (const DefaultStmt &S,
598- mlir::Type condType,
599- CaseAttr &caseEntry) {
586+ void CIRGenFunction::insertFallthrough (const clang::Stmt &S) {
587+ builder.create <YieldOp>(
588+ getLoc (S.getBeginLoc ()),
589+ mlir::cir::YieldOpKindAttr::get (builder.getContext (),
590+ mlir::cir::YieldOpKind::Fallthrough),
591+ mlir::ValueRange ({}));
592+ }
593+
594+ template <typename T>
595+ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade (
596+ const T *stmt, mlir::Type condType,
597+ SmallVector<mlir::Attribute, 4 > &caseAttrs, mlir::OperationState &os) {
598+
599+ assert ((isa<CaseStmt, DefaultStmt>(stmt)) &&
600+ " only case or default stmt go here" );
601+
600602 auto res = mlir::success ();
601- auto *ctx = builder.getContext ();
602- caseEntry = mlir::cir::CaseAttr::get (
603- ctx, builder.getArrayAttr ({}),
604- CaseOpKindAttr::get (ctx, mlir::cir::CaseOpKind::Default));
605- {
603+
604+ // Update scope information with the current region we are
605+ // emitting code for. This is useful to allow return blocks to be
606+ // automatically and properly placed during cleanup.
607+ auto *region = os.addRegion ();
608+ auto *block = builder.createBlock (region);
609+ builder.setInsertionPointToEnd (block);
610+ currLexScope->updateCurrentSwitchCaseRegion ();
611+
612+ auto *sub = stmt->getSubStmt ();
613+
614+ if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
615+ insertFallthrough (*stmt);
616+ res =
617+ buildDefaultStmt (*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
618+ } else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
619+ insertFallthrough (*stmt);
620+ res = buildCaseStmt (*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
621+ } else {
606622 mlir::OpBuilder::InsertionGuard guardCase (builder);
607- res = buildStmt (S.getSubStmt (),
608- /* useCurrentScope=*/ !isa<CompoundStmt>(S.getSubStmt ()));
623+ res = buildStmt (sub, /* useCurrentScope=*/ !isa<CompoundStmt>(sub));
609624 }
610625
611- // TODO: likelihood
612626 return res;
613627}
614628
629+ mlir::LogicalResult
630+ CIRGenFunction::buildCaseStmt (const CaseStmt &S, mlir::Type condType,
631+ SmallVector<mlir::Attribute, 4 > &caseAttrs,
632+ mlir::OperationState &os) {
633+ assert ((!S.getRHS () || !S.caseStmtIsGNURange ()) &&
634+ " case ranges not implemented" );
635+
636+ auto *caseStmt = foldCaseStmt (S, condType, caseAttrs);
637+ return buildCaseDefaultCascade (caseStmt, condType, caseAttrs, os);
638+ }
639+
640+ mlir::LogicalResult
641+ CIRGenFunction::buildDefaultStmt (const DefaultStmt &S, mlir::Type condType,
642+ SmallVector<mlir::Attribute, 4 > &caseAttrs,
643+ mlir::OperationState &os) {
644+ auto ctxt = builder.getContext ();
645+
646+ auto defAttr = mlir::cir::CaseAttr::get (
647+ ctxt, builder.getArrayAttr ({}),
648+ CaseOpKindAttr::get (ctxt, mlir::cir::CaseOpKind::Default));
649+
650+ caseAttrs.push_back (defAttr);
651+ return buildCaseDefaultCascade (&S, condType, caseAttrs, os);
652+ }
653+
615654static mlir::LogicalResult buildLoopCondYield (mlir::OpBuilder &builder,
616655 mlir::Location loc,
617656 mlir::Value cond) {
@@ -956,29 +995,20 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
956995 }
957996
958997 auto *caseStmt = dyn_cast<CaseStmt>(c);
959- CaseAttr caseAttr;
960- {
961- mlir::OpBuilder::InsertionGuard guardCase (builder);
962998
963- // Update scope information with the current region we are
964- // emitting code for. This is useful to allow return blocks to be
965- // automatically and properly placed during cleanup.
966- mlir::Region *caseRegion = os.addRegion ();
967- currLexScope->updateCurrentSwitchCaseRegion ();
968-
969- lastCaseBlock = builder.createBlock (caseRegion);
970- if (caseStmt)
971- res = buildCaseStmt (*caseStmt, condV.getType (), caseAttr);
972- else {
973- auto *defaultStmt = dyn_cast<DefaultStmt>(c);
974- assert (defaultStmt && " expected default stmt" );
975- res = buildDefaultStmt (*defaultStmt, condV.getType (), caseAttr);
976- }
977-
978- if (res.failed ())
979- break ;
999+ if (caseStmt)
1000+ res = buildCaseStmt (*caseStmt, condV.getType (), caseAttrs, os);
1001+ else {
1002+ auto *defaultStmt = dyn_cast<DefaultStmt>(c);
1003+ assert (defaultStmt && " expected default stmt" );
1004+ res = buildDefaultStmt (*defaultStmt, condV.getType (), caseAttrs,
1005+ os);
9801006 }
981- caseAttrs.push_back (caseAttr);
1007+
1008+ lastCaseBlock = builder.getBlock ();
1009+
1010+ if (res.failed ())
1011+ break ;
9821012 }
9831013
9841014 os.addAttribute (" cases" , builder.getArrayAttr (caseAttrs));
@@ -1055,4 +1085,4 @@ void CIRGenFunction::buildReturnOfRValue(mlir::Location loc, RValue RV,
10551085 llvm_unreachable (" NYI" );
10561086 }
10571087 buildBranchThroughCleanup (loc, ReturnBlock ());
1058- }
1088+ }
0 commit comments