1212#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1313#include " mlir/IR/Builders.h"
1414#include " mlir/IR/TypeUtilities.h"
15+ #include " mlir/Interfaces/ViewLikeInterface.h"
1516
1617#include " llvm/Support/Debug.h"
1718
@@ -112,6 +113,68 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
112113// ===----------------------------------------------------------------------===//
113114// XeGPU_CreateNdDescOp
114115// ===----------------------------------------------------------------------===//
116+
117+ void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
118+ Type tdesc, TypedValue<MemRefType> source) {
119+ [[maybe_unused]] auto ty = source.getType ();
120+ assert (ty.hasStaticShape () && " expecting a memref with static shape" );
121+
122+ build (builder, state, tdesc, source, ValueRange ({}) /* dynamic offsets */ ,
123+ ValueRange ({}) /* empty dynamic shape */ ,
124+ ValueRange ({}) /* empty dynamic strides */ ,
125+ DenseI64ArrayAttr ({}) /* const offsets */ ,
126+ DenseI64ArrayAttr ({}) /* empty const shape*/ ,
127+ DenseI64ArrayAttr ({}) /* empty const strides*/ );
128+ }
129+
130+ void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
131+ Type tdesc, TypedValue<MemRefType> source,
132+ llvm::ArrayRef<OpFoldResult> shape,
133+ llvm::ArrayRef<OpFoldResult> strides) {
134+ assert (shape.size () && strides.size () && shape.size () == strides.size () &&
135+ " Shape and strides must be present and of equal size for ui64 "
136+ " initialization." );
137+
138+ llvm::SmallVector<int64_t > staticShape;
139+ llvm::SmallVector<int64_t > staticStrides;
140+ llvm::SmallVector<Value> dynamicShape;
141+ llvm::SmallVector<Value> dynamicStrides;
142+
143+ dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
144+ dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
145+
146+ auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
147+ auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
148+
149+ build (builder, state, tdesc, source, ValueRange ({}), dynamicShape,
150+ dynamicStrides, builder.getDenseI64ArrayAttr ({}), staticShapeAttr,
151+ staticStridesAttr);
152+ }
153+
154+ void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
155+ Type tdesc, TypedValue<IntegerType> source,
156+ llvm::ArrayRef<OpFoldResult> shape,
157+ llvm::ArrayRef<OpFoldResult> strides) {
158+ assert (shape.size () && strides.size () && shape.size () == strides.size () &&
159+ " Shape and strides must be present and of equal size for ui64 "
160+ " initialization." );
161+
162+ llvm::SmallVector<int64_t > staticShape;
163+ llvm::SmallVector<int64_t > staticStrides;
164+ llvm::SmallVector<Value> dynamicShape;
165+ llvm::SmallVector<Value> dynamicStrides;
166+
167+ dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
168+ dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
169+
170+ auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
171+ auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
172+
173+ build (builder, state, tdesc, source, ValueRange ({}), dynamicShape,
174+ dynamicStrides, builder.getDenseI64ArrayAttr ({}), staticShapeAttr,
175+ staticStridesAttr);
176+ }
177+
115178void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
116179 Type tdesc, TypedValue<MemRefType> source,
117180 llvm::ArrayRef<OpFoldResult> offsets) {
@@ -125,8 +188,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
125188 build (builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */ ,
126189 ValueRange ({}) /* empty dynamic shape */ ,
127190 ValueRange ({}) /* empty dynamic strides */ ,
128- staticOffsets /* const offsets */ , {} /* empty const shape */ ,
129- {} /* empty const strides*/ );
191+ builder. getDenseI64ArrayAttr ( staticOffsets) /* const offsets */ ,
192+ {} /* empty const shape */ , {} /* empty const strides*/ );
130193}
131194
132195void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
@@ -197,6 +260,13 @@ LogicalResult CreateNdDescOp::verify() {
197260 invalidElemTy |= memrefTy.getElementType () != getElementType ();
198261 }
199262
263+ if (llvm::isa<IntegerType>(getSourceType ())) {
264+ // strides and shape must present for integer source.
265+ if (getMixedStrides ().empty () || getMixedSizes ().empty ())
266+ return emitOpError (" Expecting strides and shape to be present for "
267+ " integer source." );
268+ }
269+
200270 // mismatches among shape, strides, and offsets are
201271 // already handeled by OffsetSizeAndStrideOpInterface.
202272 // So they are not check here.
@@ -221,6 +291,53 @@ LogicalResult CreateNdDescOp::verify() {
221291 return success ();
222292}
223293
294+ ParseResult parseOptionalDynamicIndexList (
295+ OpAsmParser &parser,
296+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
297+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr ,
298+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
299+
300+ SmallVector<int64_t , 4 > integerVals;
301+ auto parseIntegerOrValue = [&]() {
302+ OpAsmParser::UnresolvedOperand operand;
303+ auto res = parser.parseOptionalOperand (operand);
304+
305+ if (res.has_value () && succeeded (res.value ())) {
306+ values.push_back (operand);
307+ integerVals.push_back (ShapedType::kDynamic );
308+ if (valueTypes && parser.parseColonType (valueTypes->emplace_back ()))
309+ return failure ();
310+ } else {
311+ int64_t integer;
312+ if (failed (parser.parseInteger (integer)))
313+ return failure ();
314+ integerVals.push_back (integer);
315+ }
316+ return success ();
317+ };
318+
319+ // If the optional values are given there must be left bracket
320+ if (parser.parseOptionalLSquare ().succeeded ()) {
321+ if (parser.parseCommaSeparatedList (parseIntegerOrValue) ||
322+ parser.parseRSquare ())
323+ return parser.emitError (parser.getNameLoc ())
324+ << " expected a list of SSA values or integers" ;
325+ integers = parser.getBuilder ().getDenseI64ArrayAttr (integerVals);
326+ return success ();
327+ }
328+
329+ return success ();
330+ }
331+
332+ void printOptionalDynamicIndexList (
333+ OpAsmPrinter &printer, Operation *op, OperandRange values,
334+ ArrayRef<int64_t > integers, TypeRange valueTypes = TypeRange(),
335+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
336+
337+ return printDynamicIndexList (printer, op, values, integers,
338+ /* scalableFlags=*/ {}, valueTypes, delimiter);
339+ }
340+
224341// ===----------------------------------------------------------------------===//
225342// XeGPU_PrefetchNdOp
226343// ===----------------------------------------------------------------------===//
0 commit comments