Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,8 @@ class CIRGenFunction : public CIRGenTypeCache {
const CIRGenFunctionInfo *CurFnInfo;
clang::QualType FnRetTy;

/// This is the current function or global initializer that is generated code for.
/// This is the current function or global initializer that is generated code
/// for.
mlir::Operation *CurFn = nullptr;

/// Save Parameter Decl for coroutine.
Expand All @@ -593,7 +594,7 @@ class CIRGenFunction : public CIRGenTypeCache {

CIRGenModule &getCIRGenModule() { return CGM; }

mlir::Block* getCurFunctionEntryBlock() {
mlir::Block *getCurFunctionEntryBlock() {
auto Fn = dyn_cast<mlir::cir::FuncOp>(CurFn);
assert(Fn && "other callables NYI");
return &Fn.getRegion().front();
Expand Down Expand Up @@ -1120,13 +1121,26 @@ class CIRGenFunction : public CIRGenTypeCache {

mlir::Type getCIRType(const clang::QualType &type);

const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

void insertFallthrough(const clang::Stmt &S);

template <typename T>
mlir::LogicalResult
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os);

mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
mlir::Type condType,
mlir::cir::CaseAttr &caseEntry);
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);

mlir::LogicalResult buildDefaultStmt(const clang::DefaultStmt &S,
mlir::Type condType,
mlir::cir::CaseAttr &caseEntry);
mlir::LogicalResult
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);

mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
const CIRGenFunctionInfo &FnInfo);
Expand Down
149 changes: 90 additions & 59 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,62 +554,102 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S,
mlir::Type condType,
CaseAttr &caseEntry) {
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
"case ranges not implemented");
auto res = mlir::success();

const CaseStmt *
CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
const CaseStmt *caseStmt = &S;
const CaseStmt *lastCase = &S;
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;

// Fold cascading cases whenever possible to simplify codegen a bit.
while (true) {
while (caseStmt) {
lastCase = caseStmt;
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal));
if (isa<CaseStmt>(caseStmt->getSubStmt()))
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
else
break;
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
}

auto caseValueList = builder.getArrayAttr(caseEltValueListAttr);
auto *ctxt = builder.getContext();

auto *ctx = builder.getContext();
caseEntry = mlir::cir::CaseAttr::get(
ctx, caseValueList,
CaseOpKindAttr::get(ctx, caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal));
{
mlir::OpBuilder::InsertionGuard guardCase(builder);
res = buildStmt(
caseStmt->getSubStmt(),
/*useCurrentScope=*/!isa<CompoundStmt>(caseStmt->getSubStmt()));
}
auto caseAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(caseEltValueListAttr),
CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal));

// TODO: likelihood
return res;
caseAttrs.push_back(caseAttr);

return lastCase;
}

mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S,
mlir::Type condType,
CaseAttr &caseEntry) {
void CIRGenFunction::insertFallthrough(const clang::Stmt &S) {
builder.create<YieldOp>(
getLoc(S.getBeginLoc()),
mlir::cir::YieldOpKindAttr::get(builder.getContext(),
mlir::cir::YieldOpKind::Fallthrough),
mlir::ValueRange({}));
}

template <typename T>
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {

assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
"only case or default stmt go here");

auto res = mlir::success();
auto *ctx = builder.getContext();
caseEntry = mlir::cir::CaseAttr::get(
ctx, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctx, mlir::cir::CaseOpKind::Default));
{

// Update scope information with the current region we are
// emitting code for. This is useful to allow return blocks to be
// automatically and properly placed during cleanup.
auto *region = os.addRegion();
auto *block = builder.createBlock(region);
builder.setInsertionPointToEnd(block);
currLexScope->updateCurrentSwitchCaseRegion();

auto *sub = stmt->getSubStmt();

if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
insertFallthrough(*stmt);
res =
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
insertFallthrough(*stmt);
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
} else {
mlir::OpBuilder::InsertionGuard guardCase(builder);
res = buildStmt(S.getSubStmt(),
/*useCurrentScope=*/!isa<CompoundStmt>(S.getSubStmt()));
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
}

// TODO: likelihood
return res;
}

mlir::LogicalResult
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
"case ranges not implemented");

auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
}

mlir::LogicalResult
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
auto ctxt = builder.getContext();

auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));

caseAttrs.push_back(defAttr);
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
}

static mlir::LogicalResult buildLoopCondYield(mlir::OpBuilder &builder,
mlir::Location loc,
mlir::Value cond) {
Expand Down Expand Up @@ -954,29 +994,20 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
}

auto *caseStmt = dyn_cast<CaseStmt>(c);
CaseAttr caseAttr;
{
mlir::OpBuilder::InsertionGuard guardCase(builder);

// Update scope information with the current region we are
// emitting code for. This is useful to allow return blocks to be
// automatically and properly placed during cleanup.
mlir::Region *caseRegion = os.addRegion();
currLexScope->updateCurrentSwitchCaseRegion();

lastCaseBlock = builder.createBlock(caseRegion);
if (caseStmt)
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttr);
else {
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
assert(defaultStmt && "expected default stmt");
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttr);
}

if (res.failed())
break;
if (caseStmt)
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
else {
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
assert(defaultStmt && "expected default stmt");
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
os);
}
caseAttrs.push_back(caseAttr);

lastCaseBlock = builder.getBlock();

if (res.failed())
break;
}

os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
Expand Down
100 changes: 99 additions & 1 deletion clang/test/CIR/CodeGen/switch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ void sw1(int a) {
}
}
}

// CHECK: cir.func @_Z3sw1i
// CHECK: cir.switch (%3 : !s32i) [
// CHECK-NEXT: case (equal, 0) {
Expand Down Expand Up @@ -160,3 +159,102 @@ void sw7(int a) {
// CHECK-NEXT: case (anyof, [3, 4, 5] : !s32i) {
// CHECK-NEXT: cir.yield break
// CHECK-NEXT: }

void sw8(int a) {
switch (a)
{
case 3:
break;
case 4:
default:
break;
}
}

//CHECK: cir.func @_Z3sw8i
//CHECK: case (equal, 3)
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: },
//CHECK-NEXT: case (equal, 4) {
//CHECK-NEXT: cir.yield fallthrough
//CHECK-NEXT: }
//CHECK-NEXT: case (default) {
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: }

void sw9(int a) {
switch (a)
{
case 3:
break;
default:
case 4:
break;
}
}

//CHECK: cir.func @_Z3sw9i
//CHECK: case (equal, 3) {
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: }
//CHECK-NEXT: case (default) {
//CHECK-NEXT: cir.yield fallthrough
//CHECK-NEXT: }
//CHECK: case (equal, 4)
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: }

void sw10(int a) {
switch (a)
{
case 3:
break;
case 4:
default:
case 5:
break;
}
}

//CHECK: cir.func @_Z4sw10i
//CHECK: case (equal, 3)
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: },
//CHECK-NEXT: case (equal, 4) {
//CHECK-NEXT: cir.yield fallthrough
//CHECK-NEXT: }
//CHECK-NEXT: case (default) {
//CHECK-NEXT: cir.yield fallthrough
//CHECK-NEXT: }
//CHECK-NEXT: case (equal, 5) {
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: }

void sw11(int a) {
switch (a)
{
case 3:
break;
case 4:
case 5:
default:
case 6:
case 7:
break;
}
}

//CHECK: cir.func @_Z4sw11i
//CHECK: case (equal, 3)
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: },
//CHECK-NEXT: case (anyof, [4, 5] : !s32i) {
//CHECK-NEXT: cir.yield fallthrough
//CHECK-NEXT: }
//CHECK-NEXT: case (default) {
//CHECK-NEXT: cir.yield fallthrough
//CHECK-NEXT: }
//CHECK-NEXT: case (anyof, [6, 7] : !s32i) {
//CHECK-NEXT: cir.yield break
//CHECK-NEXT: }