Skip to content

Commit

Permalink
[mlir][transform] Consistent linalg transform op syntax for dynam…
Browse files Browse the repository at this point in the history
…ic index lists (#90897)

This patch is a first pass at making consistent syntax across the
`LinalgTransformOp`s that use dynamic index lists for size parameters.
Previously, there were two different forms: inline types in the list, or
place them in the functional style tuple. This patch goes for the
latter.

In order to do this, the `printPackedOrDynamicIndexList`,
`printDynamicIndexList` and their `parse` counterparts were modified so
that the types can be optionally provided to the corresponding custom
directives.

All affected ops now use tablegen `assemblyFormat`, so custom
`parse`/`print` functions have been removed. There are a couple ops that
will likely add dynamic size support, and once that happens it should be
made sure that the assembly remains consistent with the changes in this
patch.

The affected ops are as follows: `pack`, `pack_greedily`,
`tile_using_forall`. The `tile_using_for` and `vectorize` ops already
used this syntax, but their custom assembly was removed.

---------

Co-authored-by: Oleksandr "Alex" Zinenko <ftynse@gmail.com>
  • Loading branch information
srcarroll and ftynse authored May 8, 2024
1 parent c6efcc9 commit 2c1c676
Show file tree
Hide file tree
Showing 53 changed files with 210 additions and 323 deletions.
41 changes: 28 additions & 13 deletions mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -783,10 +783,9 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
let assemblyFormat = [{
$target
`packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
$static_packed_sizes,
type($packed_sizes))
$static_packed_sizes)
attr-dict
`:` functional-type($target, results)
`:` functional-type(operands, results)
}];

let builders = [
Expand Down Expand Up @@ -890,14 +889,13 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
$target
oilist(
`matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
$static_matmul_packed_sizes,
type($matmul_packed_sizes))
$static_matmul_packed_sizes)
(`matmul_padded_sizes_next_multiple_of` `=`
$matmul_padded_sizes_next_multiple_of^)?
`matmul_inner_dims_order` `=` $matmul_inner_dims_order
)
attr-dict
`:` functional-type($target, results)
`:` functional-type(operands, results)
}];
let hasVerifier = 1;

Expand Down Expand Up @@ -1899,7 +1897,17 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
$scalableSizes)>,
];

let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$target
`tile_sizes` custom<DynamicIndexList>(
$dynamic_sizes,
$static_sizes,
$scalable_sizes)
(`interchange` `=` $interchange^)?
attr-dict
`:` functional-type(operands, results)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
Expand Down Expand Up @@ -2017,17 +2025,13 @@ def TileUsingForallOp :
let assemblyFormat = [{
$target oilist(
`num_threads` custom<PackedOrDynamicIndexList>($packed_num_threads,
type($packed_num_threads),
$num_threads,
type($num_threads),
$static_num_threads) |
`tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
type($packed_tile_sizes),
$tile_sizes,
type($tile_sizes),
$static_tile_sizes))
(`(` `mapping` `=` $mapping^ `)`)? attr-dict
`:` functional-type($target, results)
`:` functional-type(operands, results)
}];
let hasVerifier = 1;

Expand Down Expand Up @@ -2162,7 +2166,18 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",

let results = (outs);

let hasCustomAssemblyFormat = 1;
// We use oilist here to elide the optional `vector_sizes` when empty list
// is passed.
let assemblyFormat = [{
$target oilist(
`vector_sizes` custom<DynamicIndexList>(
$vector_sizes,
$static_vector_sizes,
$scalable_sizes))
attr-dict
`:` type($target)(`,`type($vector_sizes)^)?
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
Expand Down
16 changes: 15 additions & 1 deletion mlir/include/mlir/Dialect/Transform/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
Value packed, Type packedType,
OperandRange values, TypeRange valueTypes,
DenseI64ArrayAttr integers);
inline void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
Value packed, OperandRange values,
DenseI64ArrayAttr integers) {
printPackedOrDynamicIndexList(printer, op, packed, Type(), values,
TypeRange{}, integers);
}

/// Parser hook for custom directive in assemblyFormat.
///
Expand All @@ -47,7 +53,15 @@ void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
ParseResult parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers);
SmallVectorImpl<Type> *valueTypes, DenseI64ArrayAttr &integers);
inline ParseResult parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers) {
Type packedType;
return parsePackedOrDynamicIndexList(parser, packed, packedType, values,
nullptr, integers);
}
} // namespace transform
} // namespace mlir

Expand Down
11 changes: 9 additions & 2 deletions mlir/include/mlir/Interfaces/ViewLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,16 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
/// empty then assume that all indices are non-scalable.
void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
ArrayRef<bool> scalables = {},
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
inline void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
delimiter);
}

/// Parser hook for custom directive in assemblyFormat.
///
Expand Down
154 changes: 0 additions & 154 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2823,86 +2823,6 @@ SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
return results;
}

// We want to parse `DenseI64ArrayAttr` using the short form without the
// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
ParseResult parseOptionalInterchange(OpAsmParser &parser,
OperationState &result) {
if (failed(parser.parseOptionalKeyword("interchange")))
return success();
if (failed(parser.parseEqual()))
return failure();
result.addAttribute(
transform::TileUsingForOp::getInterchangeAttrName(result.name),
DenseI64ArrayAttr::parse(parser, Type{}));
return success();
}

void printOptionalInterchange(OpAsmPrinter &p,
ArrayRef<int64_t> interchangeVals) {
if (!interchangeVals.empty()) {
p << " interchange = [";
llvm::interleaveComma(interchangeVals, p,
[&](int64_t integer) { p << integer; });
p << "]";
}
}

ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
DenseI64ArrayAttr staticSizes;
FunctionType functionalType;
llvm::SMLoc operandLoc;
DenseBoolArrayAttr scalableVals;

if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
parseOptionalInterchange(parser, result) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(functionalType))
return ParseResult::failure();

size_t numExpectedLoops =
staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
if (functionalType.getNumResults() != numExpectedLoops + 1) {
return parser.emitError(parser.getNameLoc())
<< "expected " << (numExpectedLoops + 1) << " result type(s)";
}
if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
return parser.emitError(operandLoc)
<< "expected " << dynamicSizes.size() + 1 << " operand type(s)";
}
if (parser.resolveOperand(target, functionalType.getInputs().front(),
result.operands) ||
parser.resolveOperands(dynamicSizes,
functionalType.getInputs().drop_front(),
operandLoc, result.operands)) {
return failure();
}

result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);

result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
result.addTypes(functionalType.getResults());
return success();
}

void TileUsingForOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
/*valueTypes=*/{}, getScalableSizesAttr(),
OpAsmParser::Delimiter::Square);
printOptionalInterchange(p, getInterchange());
p.printOptionalAttrDict(
(*this)->getAttrs(),
/*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
getScalableSizesAttrName(getOperation()->getName()),
getStaticSizesAttrName(getOperation()->getName())});
p << " : ";
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
}

void transform::TileUsingForOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
Expand Down Expand Up @@ -3219,80 +3139,6 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
// VectorizeOp
//===----------------------------------------------------------------------===//

static const StringLiteral kVectorSizesKeyword = "vector_sizes";

ParseResult transform::VectorizeOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
DenseI64ArrayAttr staticSizes;
SmallVector<Type> operandTypes;
llvm::SMLoc operandLoc;
DenseBoolArrayAttr scalableVals;

if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
return ParseResult::failure();

if (succeeded(parser.parseOptionalKeyword(kVectorSizesKeyword))) {
if (failed(parseDynamicIndexList(parser, dynamicSizes, staticSizes,
scalableVals)))
return ParseResult::failure();
}

if (succeeded(parser.parseOptionalKeyword(
getVectorizeNdExtractAttrName(result.name))))
result.addAttribute(getVectorizeNdExtractAttrName(result.name),
parser.getBuilder().getUnitAttr());

if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(operandTypes))
return ParseResult::failure();

if (operandTypes.size() != dynamicSizes.size() + 1) {
return parser.emitError(operandLoc)
<< "expected " << dynamicSizes.size() + 1 << " operand type(s)";
}
if (parser.resolveOperand(target, operandTypes.front(), result.operands) ||
parser.resolveOperands(dynamicSizes, ArrayRef(operandTypes).drop_front(),
operandLoc, result.operands)) {
return failure();
}

if (scalableVals)
result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
if (staticSizes)
result.addAttribute(getStaticVectorSizesAttrName(result.name), staticSizes);

return success();
}

void transform::VectorizeOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget() << ' ';
if (!getMixedVectorSizes().empty()) {
p << kVectorSizesKeyword << ' ';
printDynamicIndexList(p, getOperation(), getVectorSizes(),
getStaticVectorSizesAttr(),
/*valueTypes=*/{}, getScalableSizesAttr(),
OpAsmParser::Delimiter::Square);
}

if (getVectorizeNdExtract())
p << getVectorizeNdExtractAttrName() << ' ';

p.printOptionalAttrDict(
(*this)->getAttrs(),
/*elidedAttrs=*/{
getScalableSizesAttrName(getOperation()->getName()),
getStaticVectorSizesAttrName(getOperation()->getName())});
p << " : ";
p << getTarget().getType();
if (!getVectorSizes().empty()) {
p << ", ";
llvm::interleaveComma(getVectorSizes(), p,
[&](Value operand) { p << operand.getType(); });
}
}

DiagnosedSilenceableFailure transform::VectorizeOp::apply(
transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
Expand Down
19 changes: 12 additions & 7 deletions mlir/lib/Dialect/Transform/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ void mlir::transform::printPackedOrDynamicIndexList(
if (packed) {
assert(values.empty() && (!integers || integers.empty()) &&
"expected no values/integers");
printer << "*(" << packed << " : " << packedType << ")";
printer << "*(" << packed;
if (packedType) {
printer << " : " << packedType;
}
printer << ")";
return;
}
printDynamicIndexList(printer, op, values, integers, valueTypes);
Expand All @@ -29,19 +33,20 @@ void mlir::transform::printPackedOrDynamicIndexList(
ParseResult mlir::transform::parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers) {
SmallVectorImpl<Type> *valueTypes, DenseI64ArrayAttr &integers) {
OpAsmParser::UnresolvedOperand packedOperand;
if (parser.parseOptionalStar().succeeded()) {
if (parser.parseLParen().failed() ||
parser.parseOperand(packedOperand).failed() ||
parser.parseColonType(packedType).failed() ||
parser.parseRParen().failed()) {
parser.parseOperand(packedOperand).failed())
return failure();
if (packedType && (parser.parseColonType(packedType).failed()))
return failure();
if (parser.parseRParen().failed())
return failure();
}
packed.emplace(packedOperand);
integers = parser.getBuilder().getDenseI64ArrayAttr({});
return success();
}

return parseDynamicIndexList(parser, values, integers, &valueTypes);
return parseDynamicIndexList(parser, values, integers, valueTypes);
}
2 changes: 1 addition & 1 deletion mlir/lib/Interfaces/ViewLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) {
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
TypeRange valueTypes, ArrayRef<bool> scalables,
ArrayRef<bool> scalables, TypeRange valueTypes,
AsmParser::Delimiter delimiter) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVM/transform-e2e.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func.func @matmul_tensors(
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.consumed}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
%1, %loops:3 = transform.structured.tile_using_for %0 [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%2 = transform.get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
%b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:4 = transform.structured.tile_using_for %0 [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
Expand All @@ -54,7 +54,7 @@ func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %a
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.tile_using_for %0 [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -85,7 +85,7 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.tile_using_for %0 [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
Loading

0 comments on commit 2c1c676

Please sign in to comment.