Skip to content

[OpenACC][CIR] Implement 'num_gangs' lowering #137216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 120 additions & 29 deletions clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,78 @@ class OpenACCClauseCIREmitter final
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
}

// Handle a clause affected by the 'device-type' to the point that they need
// to have the attributes added in the correct/corresponding order, such as
// 'num_workers' or 'vector_length' on a compute construct. For cases where we
// don't have an expression 'argument' that needs to be added to an operand
// and only care about the 'device-type' list, we can use this with 'argument'
// as 'std::nullopt'. If 'argument' is NOT 'std::nullopt' (that is, has a
// value), argCollection must also be non-null. For cases where we don't have
// an argument that needs to be added to an additional one (such as asyncOnly)
// we can use this with 'argument' as std::nullopt.
mlir::ArrayAttr handleDeviceTypeAffectedClause(
mlir::ArrayAttr existingDeviceTypes,
std::optional<mlir::Value> argument = std::nullopt,
mlir::MutableOperandRange *argCollection = nullptr) {
// Overload of this function that only returns the device-types list.
mlir::ArrayAttr
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) {
mlir::ValueRange argument;
mlir::MutableOperandRange range{operation};

return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, range);
}
// Overload of this function for when 'segments' aren't necessary.
mlir::ArrayAttr
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
mlir::ValueRange argument,
mlir::MutableOperandRange argCollection) {
llvm::SmallVector<int32_t> segments;
assert(argument.size() <= 1 &&
"Overload only for cases where segments don't need to be added");
return handleDeviceTypeAffectedClause(existingDeviceTypes, argument,
argCollection, segments);
}

// Handle a clause affected by the 'device_type' to the point that they need
// to have attributes added in the correct/corresponding order, such as
// 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
// a collection of operands that need to be appended to the `argCollection` as
// we're adding a 'device_type' entry. If there is more than 0 elements in
// the 'argument', the collection must be non-null, as it is needed to add to
// it.
// As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me what segments is. Can you add more detail in the comment explaining? It looks like you're pushing the number of arguments in the clause being handled? Can there be a segment with zero arguments between non-zero segments?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. 'segments' are a little weird, they are a little bit MLIR/OpenACC-Dialect specific perhaps. I'll try to improve the comment.

As far as zero-arguments between non-zero segments, my understanding is no.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fantastic! Thanks for the expanded explanation.

// be maintained, this takes a list of segments that will be updated with the
// proper counts as 'argument' elements are added.
//
// In MLIR, the 'operands' are stored as a large array, with a separate array
// of 'segments' that show which 'operand' applies to which 'operand-kind'.
// That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
//
// So the operands array might have 4 elements, but the 'segments' array will
// be something like:
//
// {0, 0, 0, 2, 0, 1, 1, 0, 0...}
//
// Where each position belongs to a specific 'operand-kind'. So that
// specifies that whichever operand-kind corresponds with index '3' has 2
// elements, and should take the 1st 2 operands off the list (since all
// preceding values are 0). operand-kinds corresponding to 5 and 6 each have
// 1 element.
//
// Fortunately, the `MutableOperandRange` append function actually takes care
// of that for us at the 'top level'.
//
// However, in cases like `num_gangs' or 'wait', where each individual
// 'element' might be itself array-like, there is a separate 'segments' array
// for them. So in the case of:
//
// device_type(nvidia, radeon) num_gangs(1, 2, 3)
//
// We have to emit that as TWO arrays into the IR (where the device_type is an
// attribute), so they look like:
//
// num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
// {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
//
// When stored in the 'operands' list, the top-level 'segment' for
// 'num_gangs' just shows 6 elements. In order to get the array-like
// apperance, the 'numGangsSegments' list is kept as well. In the above case,
// we've inserted 6 operands, so the 'numGangsSegments' must contain 2
// elements, 1 per array, and each will have a value of 3. The verifier will
// ensure that the collections counts are correct.
mlir::ArrayAttr
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
mlir::ValueRange argument,
mlir::MutableOperandRange argCollection,
llvm::SmallVector<int32_t> &segments) {
llvm::SmallVector<mlir::Attribute> deviceTypes;

// Collect the 'existing' device-type attributes so we can re-create them
Expand All @@ -126,18 +185,18 @@ class OpenACCClauseCIREmitter final
lastDeviceTypeClause->getArchitectures()) {
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
builder.getContext(), decodeDeviceType(arch.getIdentifierInfo())));
if (argument) {
assert(argCollection);
argCollection->append(*argument);
if (!argument.empty()) {
argCollection.append(argument);
segments.push_back(argument.size());
}
}
} else {
// Else, we just add a single for 'none'.
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None));
if (argument) {
assert(argCollection);
argCollection->append(*argument);
if (!argument.empty()) {
argCollection.append(argument);
segments.push_back(argument.size());
}
}

Expand Down Expand Up @@ -170,7 +229,8 @@ class OpenACCClauseCIREmitter final
break;
}
} else {
// Combined Constructs left.
// TODO: When we've implemented this for everything, switch this to an
// unreachable. Combined constructs remain.
return clauseNotImplemented(clause);
}
}
Expand Down Expand Up @@ -210,7 +270,8 @@ class OpenACCClauseCIREmitter final
// they just modify the other clauses IR. So setting of `lastDeviceType`
// (done above) is all we need.
} else {
// update, data, loop, routine, combined remain.
// TODO: When we've implemented this for everything, switch this to an
// unreachable. update, data, loop, routine, combined constructs remain.
return clauseNotImplemented(clause);
}
}
Expand All @@ -220,11 +281,12 @@ class OpenACCClauseCIREmitter final
mlir::MutableOperandRange range = operation.getNumWorkersMutable();
operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getNumWorkersDeviceTypeAttr(),
createIntExpr(clause.getIntExpr()), &range));
createIntExpr(clause.getIntExpr()), range));
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
llvm_unreachable("num_workers not valid on serial");
} else {
// Combined Remain.
// TODO: When we've implemented this for everything, switch this to an
// unreachable. Combined constructs remain.
return clauseNotImplemented(clause);
}
}
Expand All @@ -234,11 +296,12 @@ class OpenACCClauseCIREmitter final
mlir::MutableOperandRange range = operation.getVectorLengthMutable();
operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getVectorLengthDeviceTypeAttr(),
createIntExpr(clause.getIntExpr()), &range));
createIntExpr(clause.getIntExpr()), range));
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
llvm_unreachable("vector_length not valid on serial");
} else {
// Combined remain.
// TODO: When we've implemented this for everything, switch this to an
// unreachable. Combined constructs remain.
return clauseNotImplemented(clause);
}
}
Expand All @@ -252,10 +315,12 @@ class OpenACCClauseCIREmitter final
mlir::MutableOperandRange range = operation.getAsyncOperandsMutable();
operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getAsyncOperandsDeviceTypeAttr(),
createIntExpr(clause.getIntExpr()), &range));
createIntExpr(clause.getIntExpr()), range));
}
} else {
// Data, enter data, exit data, update, wait, combined remain.
// TODO: When we've implemented this for everything, switch this to an
// unreachable. Combined constructs remain. Data, enter data, exit data,
// update, wait, combined constructs remain.
return clauseNotImplemented(clause);
}
}
Expand All @@ -272,7 +337,8 @@ class OpenACCClauseCIREmitter final
llvm_unreachable("var-list version of self shouldn't get here");
}
} else {
// update and combined remain.
// TODO: When we've implemented this for everything, switch this to an
// unreachable. If, combined constructs remain.
return clauseNotImplemented(clause);
}
}
Expand All @@ -286,7 +352,9 @@ class OpenACCClauseCIREmitter final
// 'if' applies to most of the constructs, but hold off on lowering them
// until we can write tests/know what we're doing with codegen to make
// sure we get it right.
// Enter data, exit data, host_data, update, wait, combined remain.
// TODO: When we've implemented this for everything, switch this to an
// unreachable. Enter data, exit data, host_data, update, wait, combined
// constructs remain.
return clauseNotImplemented(clause);
}
}
Expand All @@ -301,6 +369,29 @@ class OpenACCClauseCIREmitter final
}
}

void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
llvm::SmallVector<mlir::Value> values;

for (const Expr *E : clause.getIntExprs())
values.push_back(createIntExpr(E));

llvm::SmallVector<int32_t> segments;
if (operation.getNumGangsSegments())
llvm::copy(*operation.getNumGangsSegments(),
std::back_inserter(segments));

mlir::MutableOperandRange range = operation.getNumGangsMutable();
operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getNumGangsDeviceTypeAttr(), values, range, segments));
operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments});
} else {
// TODO: When we've implemented this for everything, switch this to an
// unreachable. Combined constructs remain.
return clauseNotImplemented(clause);
}
}

void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) {
if constexpr (isOneOfTypes<OpTy, SetOp>) {
operation.getDefaultAsyncMutable().append(
Expand Down
46 changes: 46 additions & 0 deletions clang/test/CIR/CodeGenOpenACC/kernels.c
Original file line number Diff line number Diff line change
Expand Up @@ -256,5 +256,51 @@ void acc_kernels(int cond) {
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<nvidia>, #acc.device_type<radeon>]}

#pragma acc kernels num_gangs(1)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: acc.kernels num_gangs({%[[ONE_CAST]] : si32}) {
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: } loc

#pragma acc kernels num_gangs(cond)
{}
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: acc.kernels num_gangs({%[[CONV_CAST]] : si32}) {
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: } loc

#pragma acc kernels num_gangs(1) device_type(radeon) num_gangs(cond)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: acc.kernels num_gangs({%[[ONE_CAST]] : si32}, {%[[CONV_CAST]] : si32} [#acc.device_type<radeon>]) {
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: } loc

#pragma acc kernels num_gangs(1) device_type(radeon) num_gangs(6)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[SIX_LITERAL:.*]] = cir.const #cir.int<6> : !s32i
// CHECK-NEXT: %[[SIX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIX_LITERAL]] : !s32i to si32
// CHECK-NEXT: acc.kernels num_gangs({%[[ONE_CAST]] : si32}, {%[[SIX_CAST]] : si32} [#acc.device_type<radeon>]) {
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: } loc

#pragma acc kernels num_gangs(cond) device_type(radeon, nvidia) num_gangs(4)
{}
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: %[[FOUR_LITERAL:.*]] = cir.const #cir.int<4> : !s32i
// CHECK-NEXT: %[[FOUR_CAST:.*]] = builtin.unrealized_conversion_cast %[[FOUR_LITERAL]] : !s32i to si32
// CHECK-NEXT: acc.kernels num_gangs({%[[CONV_CAST]] : si32}, {%[[FOUR_CAST]] : si32} [#acc.device_type<radeon>], {%[[FOUR_CAST]] : si32} [#acc.device_type<nvidia>]) {
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: } loc

// CHECK-NEXT: cir.return
}
74 changes: 74 additions & 0 deletions clang/test/CIR/CodeGenOpenACC/parallel.c
Original file line number Diff line number Diff line change
Expand Up @@ -255,5 +255,79 @@ void acc_parallel(int cond) {
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<nvidia>, #acc.device_type<radeon>]}

#pragma acc parallel num_gangs(1)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32}) {
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } loc

#pragma acc parallel num_gangs(cond)
{}
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: acc.parallel num_gangs({%[[CONV_CAST]] : si32}) {
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } loc

#pragma acc parallel num_gangs(1, cond, 2)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32, %[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32}) {
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } loc

#pragma acc parallel num_gangs(1) device_type(radeon) num_gangs(cond)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32}, {%[[CONV_CAST]] : si32} [#acc.device_type<radeon>]) {
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } loc

#pragma acc parallel num_gangs(1, cond, 2) device_type(radeon) num_gangs(4, 5, 6)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[FOUR_LITERAL:.*]] = cir.const #cir.int<4> : !s32i
// CHECK-NEXT: %[[FOUR_CAST:.*]] = builtin.unrealized_conversion_cast %[[FOUR_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[FIVE_LITERAL:.*]] = cir.const #cir.int<5> : !s32i
// CHECK-NEXT: %[[FIVE_CAST:.*]] = builtin.unrealized_conversion_cast %[[FIVE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[SIX_LITERAL:.*]] = cir.const #cir.int<6> : !s32i
// CHECK-NEXT: %[[SIX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIX_LITERAL]] : !s32i to si32
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32, %[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32}, {%[[FOUR_CAST]] : si32, %[[FIVE_CAST]] : si32, %[[SIX_CAST]] : si32} [#acc.device_type<radeon>])
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } loc

#pragma acc parallel num_gangs(1, cond, 2) device_type(radeon, nvidia) num_gangs(4, 5, 6)
{}
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[FOUR_LITERAL:.*]] = cir.const #cir.int<4> : !s32i
// CHECK-NEXT: %[[FOUR_CAST:.*]] = builtin.unrealized_conversion_cast %[[FOUR_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[FIVE_LITERAL:.*]] = cir.const #cir.int<5> : !s32i
// CHECK-NEXT: %[[FIVE_CAST:.*]] = builtin.unrealized_conversion_cast %[[FIVE_LITERAL]] : !s32i to si32
// CHECK-NEXT: %[[SIX_LITERAL:.*]] = cir.const #cir.int<6> : !s32i
// CHECK-NEXT: %[[SIX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIX_LITERAL]] : !s32i to si32
// CHECK-NEXT: acc.parallel num_gangs({%[[ONE_CAST]] : si32, %[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32}, {%[[FOUR_CAST]] : si32, %[[FIVE_CAST]] : si32, %[[SIX_CAST]] : si32} [#acc.device_type<radeon>], {%[[FOUR_CAST]] : si32, %[[FIVE_CAST]] : si32, %[[SIX_CAST]] : si32} [#acc.device_type<nvidia>])
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } loc

// CHECK-NEXT: cir.return
}
Loading