@@ -554,62 +554,102 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
554554 return mlir::success ();
555555}
556556
557- mlir::LogicalResult CIRGenFunction::buildCaseStmt (const CaseStmt &S,
558- mlir::Type condType,
559- CaseAttr &caseEntry) {
560- assert ((!S.getRHS () || !S.caseStmtIsGNURange ()) &&
561- " case ranges not implemented" );
562- auto res = mlir::success ();
563-
557+ const CaseStmt *
558+ CIRGenFunction::foldCaseStmt (const clang::CaseStmt &S, mlir::Type condType,
559+ SmallVector<mlir::Attribute, 4 > &caseAttrs) {
564560 const CaseStmt *caseStmt = &S;
561+ const CaseStmt *lastCase = &S;
565562 SmallVector<mlir::Attribute, 4 > caseEltValueListAttr;
563+
566564 // Fold cascading cases whenever possible to simplify codegen a bit.
567- while (true ) {
565+ while (caseStmt) {
566+ lastCase = caseStmt;
568567 auto intVal = caseStmt->getLHS ()->EvaluateKnownConstInt (getContext ());
569568 caseEltValueListAttr.push_back (mlir::cir::IntAttr::get (condType, intVal));
570- if (isa<CaseStmt>(caseStmt->getSubStmt ()))
571- caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt ());
572- else
573- break ;
569+ caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt ());
574570 }
575571
576- auto caseValueList = builder.getArrayAttr (caseEltValueListAttr );
572+ auto *ctxt = builder.getContext ( );
577573
578- auto *ctx = builder.getContext ();
579- caseEntry = mlir::cir::CaseAttr::get (
580- ctx, caseValueList,
581- CaseOpKindAttr::get (ctx, caseEltValueListAttr.size () > 1
582- ? mlir::cir::CaseOpKind::Anyof
583- : mlir::cir::CaseOpKind::Equal));
584- {
585- mlir::OpBuilder::InsertionGuard guardCase (builder);
586- res = buildStmt (
587- caseStmt->getSubStmt (),
588- /* useCurrentScope=*/ !isa<CompoundStmt>(caseStmt->getSubStmt ()));
589- }
574+ auto caseAttr = mlir::cir::CaseAttr::get (
575+ ctxt, builder.getArrayAttr (caseEltValueListAttr),
576+ CaseOpKindAttr::get (ctxt, caseEltValueListAttr.size () > 1
577+ ? mlir::cir::CaseOpKind::Anyof
578+ : mlir::cir::CaseOpKind::Equal));
590579
591- // TODO: likelihood
592- return res;
580+ caseAttrs.push_back (caseAttr);
581+
582+ return lastCase;
593583}
594584
595- mlir::LogicalResult CIRGenFunction::buildDefaultStmt (const DefaultStmt &S,
596- mlir::Type condType,
597- CaseAttr &caseEntry) {
585+ void CIRGenFunction::insertFallthrough (const clang::Stmt &S) {
586+ builder.create <YieldOp>(
587+ getLoc (S.getBeginLoc ()),
588+ mlir::cir::YieldOpKindAttr::get (builder.getContext (),
589+ mlir::cir::YieldOpKind::Fallthrough),
590+ mlir::ValueRange ({}));
591+ }
592+
593+ template <typename T>
594+ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade (
595+ const T *stmt, mlir::Type condType,
596+ SmallVector<mlir::Attribute, 4 > &caseAttrs, mlir::OperationState &os) {
597+
598+ assert ((isa<CaseStmt, DefaultStmt>(stmt)) &&
599+ " only case or default stmt go here" );
600+
598601 auto res = mlir::success ();
599- auto *ctx = builder.getContext ();
600- caseEntry = mlir::cir::CaseAttr::get (
601- ctx, builder.getArrayAttr ({}),
602- CaseOpKindAttr::get (ctx, mlir::cir::CaseOpKind::Default));
603- {
602+
603+ // Update scope information with the current region we are
604+ // emitting code for. This is useful to allow return blocks to be
605+ // automatically and properly placed during cleanup.
606+ auto *region = os.addRegion ();
607+ auto *block = builder.createBlock (region);
608+ builder.setInsertionPointToEnd (block);
609+ currLexScope->updateCurrentSwitchCaseRegion ();
610+
611+ auto *sub = stmt->getSubStmt ();
612+
613+ if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
614+ insertFallthrough (*stmt);
615+ res =
616+ buildDefaultStmt (*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
617+ } else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
618+ insertFallthrough (*stmt);
619+ res = buildCaseStmt (*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
620+ } else {
604621 mlir::OpBuilder::InsertionGuard guardCase (builder);
605- res = buildStmt (S.getSubStmt (),
606- /* useCurrentScope=*/ !isa<CompoundStmt>(S.getSubStmt ()));
622+ res = buildStmt (sub, /* useCurrentScope=*/ !isa<CompoundStmt>(sub));
607623 }
608624
609- // TODO: likelihood
610625 return res;
611626}
612627
628+ mlir::LogicalResult
629+ CIRGenFunction::buildCaseStmt (const CaseStmt &S, mlir::Type condType,
630+ SmallVector<mlir::Attribute, 4 > &caseAttrs,
631+ mlir::OperationState &os) {
632+ assert ((!S.getRHS () || !S.caseStmtIsGNURange ()) &&
633+ " case ranges not implemented" );
634+
635+ auto *caseStmt = foldCaseStmt (S, condType, caseAttrs);
636+ return buildCaseDefaultCascade (caseStmt, condType, caseAttrs, os);
637+ }
638+
639+ mlir::LogicalResult
640+ CIRGenFunction::buildDefaultStmt (const DefaultStmt &S, mlir::Type condType,
641+ SmallVector<mlir::Attribute, 4 > &caseAttrs,
642+ mlir::OperationState &os) {
643+ auto ctxt = builder.getContext ();
644+
645+ auto defAttr = mlir::cir::CaseAttr::get (
646+ ctxt, builder.getArrayAttr ({}),
647+ CaseOpKindAttr::get (ctxt, mlir::cir::CaseOpKind::Default));
648+
649+ caseAttrs.push_back (defAttr);
650+ return buildCaseDefaultCascade (&S, condType, caseAttrs, os);
651+ }
652+
613653static mlir::LogicalResult buildLoopCondYield (mlir::OpBuilder &builder,
614654 mlir::Location loc,
615655 mlir::Value cond) {
@@ -954,29 +994,20 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
954994 }
955995
956996 auto *caseStmt = dyn_cast<CaseStmt>(c);
957- CaseAttr caseAttr;
958- {
959- mlir::OpBuilder::InsertionGuard guardCase (builder);
960-
961- // Update scope information with the current region we are
962- // emitting code for. This is useful to allow return blocks to be
963- // automatically and properly placed during cleanup.
964- mlir::Region *caseRegion = os.addRegion ();
965- currLexScope->updateCurrentSwitchCaseRegion ();
966-
967- lastCaseBlock = builder.createBlock (caseRegion);
968- if (caseStmt)
969- res = buildCaseStmt (*caseStmt, condV.getType (), caseAttr);
970- else {
971- auto *defaultStmt = dyn_cast<DefaultStmt>(c);
972- assert (defaultStmt && " expected default stmt" );
973- res = buildDefaultStmt (*defaultStmt, condV.getType (), caseAttr);
974- }
975997
976- if (res.failed ())
977- break ;
998+ if (caseStmt)
999+ res = buildCaseStmt (*caseStmt, condV.getType (), caseAttrs, os);
1000+ else {
1001+ auto *defaultStmt = dyn_cast<DefaultStmt>(c);
1002+ assert (defaultStmt && " expected default stmt" );
1003+ res = buildDefaultStmt (*defaultStmt, condV.getType (), caseAttrs,
1004+ os);
9781005 }
979- caseAttrs.push_back (caseAttr);
1006+
1007+ lastCaseBlock = builder.getBlock ();
1008+
1009+ if (res.failed ())
1010+ break ;
9801011 }
9811012
9821013 os.addAttribute (" cases" , builder.getArrayAttr (caseAttrs));
0 commit comments