@@ -557,63 +557,102 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
557557 return mlir::success ();
558558}
559559
560- mlir::LogicalResult CIRGenFunction::buildCaseStmt (const CaseStmt &S,
561- mlir::Type condType,
562- CaseAttr &caseEntry) {
563- assert ((!S.getRHS () || !S.caseStmtIsGNURange ()) &&
564- " case ranges not implemented" );
565- auto res = mlir::success ();
566-
560+ const CaseStmt *
561+ CIRGenFunction::foldCaseStmt (const clang::CaseStmt &S, mlir::Type condType,
562+ SmallVector<mlir::Attribute, 4 > &caseAttrs) {
567563 const CaseStmt *caseStmt = &S;
564+ const CaseStmt *lastCase = &S;
568565 SmallVector<mlir::Attribute, 4 > caseEltValueListAttr;
566+
569567 // Fold cascading cases whenever possible to simplify codegen a bit.
570- while (true ) {
568+ while (caseStmt) {
569+ lastCase = caseStmt;
571570 auto intVal = caseStmt->getLHS ()->EvaluateKnownConstInt (getContext ());
572571 caseEltValueListAttr.push_back (mlir::cir::IntAttr::get (condType, intVal));
573- if (isa<CaseStmt>(caseStmt->getSubStmt ()))
574- caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt ());
575- else
576- break ;
572+ caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt ());
577573 }
578574
579- auto caseValueList = builder.getArrayAttr (caseEltValueListAttr );
575+ auto *ctxt = builder.getContext ( );
580576
581- auto *ctx = builder.getContext ();
582- caseEntry = mlir::cir::CaseAttr::get (
583- ctx, caseValueList,
584- CaseOpKindAttr::get (ctx, caseEltValueListAttr.size () > 1
585- ? mlir::cir::CaseOpKind::Anyof
586- : mlir::cir::CaseOpKind::Equal));
577+ auto caseAttr = mlir::cir::CaseAttr::get (
578+ ctxt, builder.getArrayAttr (caseEltValueListAttr),
579+ CaseOpKindAttr::get (ctxt, caseEltValueListAttr.size () > 1
580+ ? mlir::cir::CaseOpKind::Anyof
581+ : mlir::cir::CaseOpKind::Equal));
587582
588- {
589- mlir::OpBuilder::InsertionGuard guardCase (builder);
590- res = buildStmt (
591- caseStmt->getSubStmt (),
592- /* useCurrentScope=*/ !isa<CompoundStmt>(caseStmt->getSubStmt ()));
593- }
583+ caseAttrs.push_back (caseAttr);
594584
595- // TODO: likelihood
596- return res;
585+ return lastCase;
597586}
598587
599- mlir::LogicalResult CIRGenFunction::buildDefaultStmt (const DefaultStmt &S,
600- mlir::Type condType,
601- CaseAttr &caseEntry) {
588+ void CIRGenFunction::insertFallthrough (const clang::Stmt &S) {
589+ builder.create <YieldOp>(
590+ getLoc (S.getBeginLoc ()),
591+ mlir::cir::YieldOpKindAttr::get (builder.getContext (),
592+ mlir::cir::YieldOpKind::Fallthrough),
593+ mlir::ValueRange ({}));
594+ }
595+
596+ template <typename T>
597+ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade (
598+ const T *stmt, mlir::Type condType,
599+ SmallVector<mlir::Attribute, 4 > &caseAttrs, mlir::OperationState &os) {
600+
601+ assert ((isa<CaseStmt, DefaultStmt>(stmt)) &&
602+ " only case or default stmt go here" );
603+
602604 auto res = mlir::success ();
603- auto *ctx = builder.getContext ();
604- caseEntry = mlir::cir::CaseAttr::get (
605- ctx, builder.getArrayAttr ({}),
606- CaseOpKindAttr::get (ctx, mlir::cir::CaseOpKind::Default));
607- {
605+
606+ // Update scope information with the current region we are
607+ // emitting code for. This is useful to allow return blocks to be
608+ // automatically and properly placed during cleanup.
609+ auto *region = os.addRegion ();
610+ auto *block = builder.createBlock (region);
611+ builder.setInsertionPointToEnd (block);
612+ currLexScope->updateCurrentSwitchCaseRegion ();
613+
614+ auto *sub = stmt->getSubStmt ();
615+
616+ if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
617+ insertFallthrough (*stmt);
618+ res =
619+ buildDefaultStmt (*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
620+ } else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
621+ insertFallthrough (*stmt);
622+ res = buildCaseStmt (*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
623+ } else {
608624 mlir::OpBuilder::InsertionGuard guardCase (builder);
609- res = buildStmt (S.getSubStmt (),
610- /* useCurrentScope=*/ !isa<CompoundStmt>(S.getSubStmt ()));
625+ res = buildStmt (sub, /* useCurrentScope=*/ !isa<CompoundStmt>(sub));
611626 }
612627
613- // TODO: likelihood
614628 return res;
615629}
616630
631+ mlir::LogicalResult
632+ CIRGenFunction::buildCaseStmt (const CaseStmt &S, mlir::Type condType,
633+ SmallVector<mlir::Attribute, 4 > &caseAttrs,
634+ mlir::OperationState &os) {
635+ assert ((!S.getRHS () || !S.caseStmtIsGNURange ()) &&
636+ " case ranges not implemented" );
637+
638+ auto *caseStmt = foldCaseStmt (S, condType, caseAttrs);
639+ return buildCaseDefaultCascade (caseStmt, condType, caseAttrs, os);
640+ }
641+
642+ mlir::LogicalResult
643+ CIRGenFunction::buildDefaultStmt (const DefaultStmt &S, mlir::Type condType,
644+ SmallVector<mlir::Attribute, 4 > &caseAttrs,
645+ mlir::OperationState &os) {
646+ auto ctxt = builder.getContext ();
647+
648+ auto defAttr = mlir::cir::CaseAttr::get (
649+ ctxt, builder.getArrayAttr ({}),
650+ CaseOpKindAttr::get (ctxt, mlir::cir::CaseOpKind::Default));
651+
652+ caseAttrs.push_back (defAttr);
653+ return buildCaseDefaultCascade (&S, condType, caseAttrs, os);
654+ }
655+
617656static mlir::LogicalResult buildLoopCondYield (mlir::OpBuilder &builder,
618657 mlir::Location loc,
619658 mlir::Value cond) {
@@ -958,29 +997,20 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
958997 }
959998
960999 auto *caseStmt = dyn_cast<CaseStmt>(c);
961- CaseAttr caseAttr;
962- {
963- mlir::OpBuilder::InsertionGuard guardCase (builder);
9641000
965- // Update scope information with the current region we are
966- // emitting code for. This is useful to allow return blocks to be
967- // automatically and properly placed during cleanup.
968- mlir::Region *caseRegion = os.addRegion ();
969- currLexScope->updateCurrentSwitchCaseRegion ();
970-
971- lastCaseBlock = builder.createBlock (caseRegion);
972- if (caseStmt)
973- res = buildCaseStmt (*caseStmt, condV.getType (), caseAttr);
974- else {
975- auto *defaultStmt = dyn_cast<DefaultStmt>(c);
976- assert (defaultStmt && " expected default stmt" );
977- res = buildDefaultStmt (*defaultStmt, condV.getType (), caseAttr);
978- }
979-
980- if (res.failed ())
981- break ;
1001+ if (caseStmt)
1002+ res = buildCaseStmt (*caseStmt, condV.getType (), caseAttrs, os);
1003+ else {
1004+ auto *defaultStmt = dyn_cast<DefaultStmt>(c);
1005+ assert (defaultStmt && " expected default stmt" );
1006+ res = buildDefaultStmt (*defaultStmt, condV.getType (), caseAttrs,
1007+ os);
9821008 }
983- caseAttrs.push_back (caseAttr);
1009+
1010+ lastCaseBlock = builder.getBlock ();
1011+
1012+ if (res.failed ())
1013+ break ;
9841014 }
9851015
9861016 os.addAttribute (" cases" , builder.getArrayAttr (caseAttrs));
@@ -1057,4 +1087,4 @@ void CIRGenFunction::buildReturnOfRValue(mlir::Location loc, RValue RV,
10571087 llvm_unreachable (" NYI" );
10581088 }
10591089 buildBranchThroughCleanup (loc, ReturnBlock ());
1060- }
1090+ }
0 commit comments