@@ -95,19 +95,78 @@ class OpenACCClauseCIREmitter final
95
95
.CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
96
96
}
97
97
98
- // Handle a clause affected by the 'device-type' to the point that they need
99
- // to have the attributes added in the correct/corresponding order, such as
100
- // 'num_workers' or 'vector_length' on a compute construct. For cases where we
101
- // don't have an expression 'argument' that needs to be added to an operand
102
- // and only care about the 'device-type' list, we can use this with 'argument'
103
- // as 'std::nullopt'. If 'argument' is NOT 'std::nullopt' (that is, has a
104
- // value), argCollection must also be non-null. For cases where we don't have
105
- // an argument that needs to be added to an additional one (such as asyncOnly)
106
- // we can use this with 'argument' as std::nullopt.
107
- mlir::ArrayAttr handleDeviceTypeAffectedClause (
108
- mlir::ArrayAttr existingDeviceTypes,
109
- std::optional<mlir::Value> argument = std::nullopt,
110
- mlir::MutableOperandRange *argCollection = nullptr ) {
98
+ // Overload for clauses with no operands to append: forwards an empty 'argument' and only returns the updated device-types list.
99
+ mlir::ArrayAttr
100
+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes) {
101
+ mlir::ValueRange argument;
102
+ mlir::MutableOperandRange range{operation};
103
+ // 'argument' is empty, so nothing is ever appended to 'range'; it is only built to satisfy the callee's signature.
104
+ return handleDeviceTypeAffectedClause (existingDeviceTypes, argument, range);
105
+ }
106
+ // Overload for when 'segments' aren't necessary: 'argument' may hold at most one value (asserted below), and the locally-built segments list is discarded after the call.
107
+ mlir::ArrayAttr
108
+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
109
+ mlir::ValueRange argument,
110
+ mlir::MutableOperandRange argCollection) {
111
+ llvm::SmallVector<int32_t > segments;
112
+ assert (argument.size () <= 1 &&
113
+ " Overload only for cases where segments don't need to be added" );
114
+ return handleDeviceTypeAffectedClause (existingDeviceTypes, argument,
115
+ argCollection, segments);
116
+ }
117
+
118
+ // Handle a clause affected by the 'device_type' to the point that they need
119
+ // to have attributes added in the correct/corresponding order, such as
120
+ // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
121
+ // a collection of operands that need to be appended to the `argCollection` as
122
+ // we're adding a 'device_type' entry. If there are more than 0 elements in
123
+ // the 'argument', the collection must be non-null, as it is needed to add to
124
+ // it.
125
+ // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
126
+ // be maintained, this takes a list of segments that will be updated with the
127
+ // proper counts as 'argument' elements are added.
128
+ //
129
+ // In MLIR, the 'operands' are stored as a large array, with a separate array
130
+ // of 'segments' that show which 'operand' applies to which 'operand-kind'.
131
+ // That is, a 'num_workers' operand-kind or 'vector_length' operand-kind.
132
+ //
133
+ // So the operands array might have 4 elements, but the 'segments' array will
134
+ // be something like:
135
+ //
136
+ // {0, 0, 0, 2, 0, 1, 1, 0, 0...}
137
+ //
138
+ // Where each position belongs to a specific 'operand-kind'. So that
139
+ // specifies that whichever operand-kind corresponds with index '3' has 2
140
+ // elements, and should take the 1st 2 operands off the list (since all
141
+ // preceding values are 0). operand-kinds corresponding to 5 and 6 each have
142
+ // 1 element.
143
+ //
144
+ // Fortunately, the `MutableOperandRange` append function actually takes care
145
+ // of that for us at the 'top level'.
146
+ //
147
+ // However, in cases like 'num_gangs' or 'wait', where each individual
148
+ // 'element' might be itself array-like, there is a separate 'segments' array
149
+ // for them. So in the case of:
150
+ //
151
+ // device_type(nvidia, radeon) num_gangs(1, 2, 3)
152
+ //
153
+ // We have to emit that as TWO arrays into the IR (where the device_type is an
154
+ // attribute), so they look like:
155
+ //
156
+ // num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],
157
+ // {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
158
+ //
159
+ // When stored in the 'operands' list, the top-level 'segment' for
160
+ // 'num_gangs' just shows 6 elements. In order to get the array-like
161
+ // appearance, the 'numGangsSegments' list is kept as well. In the above case,
162
+ // we've inserted 6 operands, so the 'numGangsSegments' must contain 2
163
+ // elements, 1 per array, and each will have a value of 3. The verifier will
164
+ // ensure that the collections counts are correct.
165
+ mlir::ArrayAttr
166
+ handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
167
+ mlir::ValueRange argument,
168
+ mlir::MutableOperandRange argCollection,
169
+ llvm::SmallVector<int32_t > &segments) {
111
170
llvm::SmallVector<mlir::Attribute> deviceTypes;
112
171
113
172
// Collect the 'existing' device-type attributes so we can re-create them
@@ -126,18 +185,18 @@ class OpenACCClauseCIREmitter final
126
185
lastDeviceTypeClause->getArchitectures ()) {
127
186
deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
128
187
builder.getContext (), decodeDeviceType (arch.getIdentifierInfo ())));
129
- if (argument) {
130
- assert ( argCollection);
131
- argCollection-> append (* argument);
188
+ if (! argument. empty () ) {
189
+ argCollection. append (argument );
190
+ segments. push_back ( argument. size () );
132
191
}
133
192
}
134
193
} else {
135
194
// Else, we just add a single for 'none'.
136
195
deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
137
196
builder.getContext (), mlir::acc::DeviceType::None));
138
- if (argument) {
139
- assert ( argCollection);
140
- argCollection-> append (* argument);
197
+ if (! argument. empty () ) {
198
+ argCollection. append (argument );
199
+ segments. push_back ( argument. size () );
141
200
}
142
201
}
143
202
@@ -170,7 +229,8 @@ class OpenACCClauseCIREmitter final
170
229
break ;
171
230
}
172
231
} else {
173
- // Combined Constructs left.
232
+ // TODO: When we've implemented this for everything, switch this to an
233
+ // unreachable. Combined constructs remain.
174
234
return clauseNotImplemented (clause);
175
235
}
176
236
}
@@ -210,7 +270,8 @@ class OpenACCClauseCIREmitter final
210
270
// they just modify the other clauses IR. So setting of `lastDeviceType`
211
271
// (done above) is all we need.
212
272
} else {
213
- // update, data, loop, routine, combined remain.
273
+ // TODO: When we've implemented this for everything, switch this to an
274
+ // unreachable. update, data, loop, routine, combined constructs remain.
214
275
return clauseNotImplemented (clause);
215
276
}
216
277
}
@@ -220,11 +281,12 @@ class OpenACCClauseCIREmitter final
220
281
mlir::MutableOperandRange range = operation.getNumWorkersMutable ();
221
282
operation.setNumWorkersDeviceTypeAttr (handleDeviceTypeAffectedClause (
222
283
operation.getNumWorkersDeviceTypeAttr (),
223
- createIntExpr (clause.getIntExpr ()), & range));
284
+ createIntExpr (clause.getIntExpr ()), range));
224
285
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
225
286
llvm_unreachable (" num_workers not valid on serial" );
226
287
} else {
227
- // Combined Remain.
288
+ // TODO: When we've implemented this for everything, switch this to an
289
+ // unreachable. Combined constructs remain.
228
290
return clauseNotImplemented (clause);
229
291
}
230
292
}
@@ -234,11 +296,12 @@ class OpenACCClauseCIREmitter final
234
296
mlir::MutableOperandRange range = operation.getVectorLengthMutable ();
235
297
operation.setVectorLengthDeviceTypeAttr (handleDeviceTypeAffectedClause (
236
298
operation.getVectorLengthDeviceTypeAttr (),
237
- createIntExpr (clause.getIntExpr ()), & range));
299
+ createIntExpr (clause.getIntExpr ()), range));
238
300
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
239
301
llvm_unreachable (" vector_length not valid on serial" );
240
302
} else {
241
- // Combined remain.
303
+ // TODO: When we've implemented this for everything, switch this to an
304
+ // unreachable. Combined constructs remain.
242
305
return clauseNotImplemented (clause);
243
306
}
244
307
}
@@ -252,10 +315,12 @@ class OpenACCClauseCIREmitter final
252
315
mlir::MutableOperandRange range = operation.getAsyncOperandsMutable ();
253
316
operation.setAsyncOperandsDeviceTypeAttr (handleDeviceTypeAffectedClause (
254
317
operation.getAsyncOperandsDeviceTypeAttr (),
255
- createIntExpr (clause.getIntExpr ()), & range));
318
+ createIntExpr (clause.getIntExpr ()), range));
256
319
}
257
320
} else {
258
- // Data, enter data, exit data, update, wait, combined remain.
321
+ // TODO: When we've implemented this for everything, switch this to an
322
+ // unreachable. Data, enter data, exit data, update, wait, combined
323
+ // constructs remain.
259
324
return clauseNotImplemented (clause);
260
325
}
261
326
}
@@ -272,7 +337,8 @@ class OpenACCClauseCIREmitter final
272
337
llvm_unreachable (" var-list version of self shouldn't get here" );
273
338
}
274
339
} else {
275
- // update and combined remain.
340
+ // TODO: When we've implemented this for everything, switch this to an
341
+ // unreachable. Update, combined constructs remain.
276
342
return clauseNotImplemented (clause);
277
343
}
278
344
}
@@ -286,7 +352,9 @@ class OpenACCClauseCIREmitter final
286
352
// 'if' applies to most of the constructs, but hold off on lowering them
287
353
// until we can write tests/know what we're doing with codegen to make
288
354
// sure we get it right.
289
- // Enter data, exit data, host_data, update, wait, combined remain.
355
+ // TODO: When we've implemented this for everything, switch this to an
356
+ // unreachable. Enter data, exit data, host_data, update, wait, combined
357
+ // constructs remain.
290
358
return clauseNotImplemented (clause);
291
359
}
292
360
}
@@ -301,6 +369,29 @@ class OpenACCClauseCIREmitter final
301
369
}
302
370
}
303
371
372
+ void VisitNumGangsClause (const OpenACCNumGangsClause &clause) {
373
+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
374
+ llvm::SmallVector<mlir::Value> values;
375
+
376
+ for (const Expr *E : clause.getIntExprs ())
377
+ values.push_back (createIntExpr (E));
378
+ // Seed 'segments' with any counts already stored on the operation so the handler below adds to them instead of replacing them.
379
+ llvm::SmallVector<int32_t > segments;
380
+ if (operation.getNumGangsSegments ())
381
+ llvm::copy (*operation.getNumGangsSegments (),
382
+ std::back_inserter (segments));
383
+ // The handler takes 'segments' by reference and updates it; store the result back on the operation afterwards.
384
+ mlir::MutableOperandRange range = operation.getNumGangsMutable ();
385
+ operation.setNumGangsDeviceTypeAttr (handleDeviceTypeAffectedClause (
386
+ operation.getNumGangsDeviceTypeAttr (), values, range, segments));
387
+ operation.setNumGangsSegments (llvm::ArrayRef<int32_t >{segments});
388
+ } else {
389
+ // TODO: When we've implemented this for everything, switch this to an
390
+ // unreachable. Combined constructs remain.
391
+ return clauseNotImplemented (clause);
392
+ }
393
+ }
394
+
304
395
void VisitDefaultAsyncClause (const OpenACCDefaultAsyncClause &clause) {
305
396
if constexpr (isOneOfTypes<OpTy, SetOp>) {
306
397
operation.getDefaultAsyncMutable ().append (
0 commit comments