Skip to content

Conversation

rolfmorel
Copy link
Contributor

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per apply() of the transform op versus for each target
payload given to the op's apply().

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per apply of the transform op versus for each target
payload given to the op.
@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per apply() of the transform op versus for each target
payload given to the op's apply().


Full diff: https://github.com/llvm/llvm-project/pull/142683.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+10-15)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+99-18)
  • (modified) mlir/test/Dialect/Transform/test-pass-application.mlir (+51-2)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e4eb67c8e14ce..b042f5e436185 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -399,15 +399,15 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
 }
 
 def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
-    [TransformOpInterface, TransformEachOpTrait,
-     FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Applies the specified registered pass or pass pipeline";
   let description = [{
     This transform applies the specified pass or pass pipeline to the targeted
     ops. The name of the pass/pipeline is specified as a string attribute, as
     set during pass/pipeline registration. Optionally, pass options may be
-    specified as a string attribute. The pass options syntax is identical to the
-    one used with "mlir-opt".
+    specified as a string attribute with the option to pass the attribute as a
+    param. The pass options syntax is identical to the one used with "mlir-opt".
 
     This op first looks for a pass pipeline with the specified name. If no such
     pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,20 +420,15 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     of targeted ops.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target,
+  let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
+                       TransformHandleTypeInterface:$target,
                        StrAttr:$pass_name,
-                       DefaultValuedAttr<StrAttr, "\"\"">:$options);
+                       DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
   let results = (outs TransformHandleTypeInterface:$result);
   let assemblyFormat = [{
-    $pass_name `to` $target attr-dict `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-      ::mlir::transform::TransformRewriter &rewriter,
-      ::mlir::Operation *target,
-      ::mlir::transform::ApplyToEachResultList &results,
-      ::mlir::transform::TransformState &state);
+    $pass_name (`with` `options` `=`
+      custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
+      `to` $target attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 673743f22249a..536c3e14fe5c0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -53,6 +53,13 @@
 
 using namespace mlir;
 
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions);
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions);
 static ParseResult parseSequenceOpOperands(
     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
     Type &rootType,
@@ -766,17 +773,38 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
 // ApplyRegisteredPassOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    ApplyToEachResultList &results, transform::TransformState &state) {
-  // Make sure that this transform is not applied to itself. Modifying the
-  // transform IR while it is being interpreted is generally dangerous. Even
-  // more so when applying passes because they may perform a wide range of IR
-  // modifications.
-  DiagnosedSilenceableFailure payloadCheck =
-      ensurePayloadIsSeparateFromTransform(*this, target);
-  if (!payloadCheck.succeeded())
-    return payloadCheck;
+void transform::ApplyRegisteredPassOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getDynamicOptionsMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  // Check whether pass options are specified, either as a dynamic param or
+  // a static attribute. In either case, options are passed as a single string.
+  StringRef options;
+  if (auto dynamicOptions = getDynamicOptions()) {
+    ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
+    if (dynamicOptionsParam.size() != 1) {
+      return emitSilenceableError()
+             << "options passed as a param must be a single value, got "
+             << dynamicOptionsParam.size();
+    }
+    if (auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0])) {
+      options = optionsStrAttr.getValue();
+    } else {
+      return emitSilenceableError()
+             << "options passed as a param must be a string, got "
+             << dynamicOptionsParam[0];
+    }
+  } else {
+    options = getStaticOptions();
+  }
 
   // Get pass or pass pipeline from registry.
   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -786,9 +814,9 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
     return emitDefiniteFailure()
            << "unknown pass or pass pipeline: " << getPassName();
 
-  // Create pass manager and run the pass or pass pipeline.
+  // Create pass manager and add the pass or pass pipeline.
   PassManager pm(getContext());
-  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
+  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
         emitError(msg);
         return failure();
       }))) {
@@ -796,16 +824,69 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
            << "failed to add pass or pass pipeline to pipeline: "
            << getPassName();
   }
-  if (failed(pm.run(target))) {
-    auto diag = emitSilenceableError() << "pass pipeline failed";
-    diag.attachNote(target->getLoc()) << "target op";
-    return diag;
+
+  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
+  for (Operation *target : targets) {
+    // Make sure that this transform is not applied to itself. Modifying the
+    // transform IR while it is being interpreted is generally dangerous. Even
+    // more so when applying passes because they may perform a wide range of IR
+    // modifications.
+    DiagnosedSilenceableFailure payloadCheck =
+        ensurePayloadIsSeparateFromTransform(*this, target);
+    if (!payloadCheck.succeeded())
+      return payloadCheck;
+
+    // Run the pass or pass pipeline on the current target operation.
+    if (failed(pm.run(target))) {
+      auto diag = emitSilenceableError() << "pass pipeline failed";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
   }
 
-  results.push_back(target);
+  // The applied pass will have directly modified the payload IR(s).
+  results.set(llvm::cast<OpResult>(getResult()), targets);
   return DiagnosedSilenceableFailure::success();
 }
 
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions) {
+  dynamicOptions = std::nullopt;
+  OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
+  OptionalParseResult hasDynamicOptions =
+      parser.parseOptionalOperand(dynamicOptionsOperand);
+
+  if (hasDynamicOptions.has_value()) {
+    if (failed(hasDynamicOptions.value()))
+      return failure();
+
+    dynamicOptions = dynamicOptionsOperand;
+    return success();
+  }
+
+  OptionalParseResult hasStaticOptions =
+      parser.parseOptionalAttribute(staticOptions);
+  if (hasStaticOptions.has_value()) {
+    if (failed(hasStaticOptions.value()))
+      return failure();
+    return success();
+  }
+
+  return success();
+}
+
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions) {
+  if (dynamicOptions) {
+    printer.printOperand(dynamicOptions);
+  } else if (!staticOptions.getValue().empty()) {
+    printer.printAttribute(staticOptions);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 3a40b462b8270..e8e0f63b28096 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
     // expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
-    transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
@@ -94,7 +94,56 @@ func.func @valid_pass_option() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @valid_dynamic_pass_option()
+func.func @valid_dynamic_pass_option() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant "top-down=false" -> !transform.any_param
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+// -----
+
+func.func @invalid_pass_option_param() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant 42 -> !transform.any_param
+    // expected-error @below {{options passed as a param must be a string, got 42}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @too_many_pass_option_params() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %x = transform.param.constant "x" -> !transform.any_param
+    %pass_options = transform.merge_handles %x, %x : !transform.any_param
+    // expected-error @below {{options passed as a param must be a single value, got 2}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for the concept
Overall LGTM but let's wait for another opinion

@fschlimb
Copy link
Contributor

fschlimb commented Jun 4, 2025

I like the possibility to have be able to provide options as SSA values.

It seems like one might want to mix static and dynamic options. Would it make sense to provide the options as a list rather than a string?

%option1 =...
transform.apply_registered_pass "canonicalize" with options = [%option1, "option2"] to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op

@rolfmorel
Copy link
Contributor Author

rolfmorel commented Jun 4, 2025

Indeed, being able to mix-and-match static arguments with those passed in dynamically - or being able to combine multiple orthogonal dynamic arguments - would be nice!

The suggested syntax of a list/ArrayAttr where elements are either strings, i.e. "option=value" pairs, or can be transform params makes sense to me. With the interpretation that these elements need to be joined by commas spaces to have a single options string to pass to the pass.

Switching to this syntax does break the (documented) property that the options argument is just the string one would pass to the pass on the commandline. On the other hand, as this string had to be statically provided anyway, you could always do the transformation to an array manually. We could keep this option available though: either the options argument is a StringAttr (maybe even coming in as a param) or it is an array (of StringAttr or SSA-values) which will be commaspace-joined.

I will have a go at updating the PR. Thanks @fschlimb!

@rolfmorel
Copy link
Contributor Author

Updated the PR so that the following syntax is accepted (no brackets as it matches the cmdline options more closely - as suggested by @fschlimb offline):

    %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
    %max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param
    %2 = transform.apply_registered_pass "canonicalize" with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1 : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

The PR is ready to be re-reviewed.

@rolfmorel rolfmorel force-pushed the transform-pass-param branch from 9acadec to 9529ea4 Compare June 6, 2025 09:44
@rolfmorel rolfmorel merged commit 4eeee41 into llvm:main Jun 6, 2025
6 of 7 checks passed
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
paul0403 added a commit to PennyLaneAI/catalyst that referenced this pull request Jul 28, 2025
**Context:**
Update llvm, mhlo and enzyme, 2025 Q3.
The latest pair of good versions, indicated by mhlo, is
tensorflow/mlir-hlo@1dd2e71
```
mhlo=1dd2e71331014ae0373f6bf900ce6be393357190
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
```

For Enzyme, we go to the latest release
https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186
```
enzyme=v0.0.186
```
with commit `8c1a596158f6194f10e8ffd56a1660a61c54337e`

**Description of the Change:**
Miscellaneous:
1. `GreedyRewriteConfig.stuff = blah` ->
`GreedyRewriteConfig.setStuff(blah)`
llvm/llvm-project#137122
2. llvm gep op `inbounds` attribute is subsumed under a gep sign wrap
enum flag llvm/llvm-project#137272
3. `arith::Constant[Int, Float]Op` builders now have the same argument
order as other ops (output type first, then arguments)
llvm/llvm-project#144636 (note that Enzyme also
noticed this EnzymeAD/Enzyme#2379 😆 )
4. The `lookupOrCreateFn` functions now take in a builder instead of
instantiating a new one llvm/llvm-project#136421
5. `getStridedElementPtr` now takes in `rewriter` as the first argument
(instead of the last), like all the other utils
llvm/llvm-project#138984
6. The following functions now return a `LogicalResult`, and will be
caught by warnings as errors as `-Wunused-result`:
- `func::FuncOp.[insert, erase]Argument(s)`
llvm/llvm-project#137130
- `getBackwardSlice()` llvm/llvm-project#140961

Things related to `transform.apply_registered_pass` op:
1. It now takes in a `dynamic_options`
llvm/llvm-project#142683. We don't need to use
this as all our pass options are static.
2. The options it takes in are now dictionaries instead of strings
llvm/llvm-project#143159

Bufferization:
1. `bufferization.to_memref` op is renamed to `bufferization.to_buffer`
llvm/llvm-project#137180
3. `bufferization.to_tensor` op's builder now needs the result type to
be explicit llvm/llvm-project#142986. This is
also needed by a patched mhlo pass.
4. The `getBuffer()` methods take in a new arg for `BufferizationState`
llvm/llvm-project#141019,
llvm/llvm-project#141466
5. `UnknownTypeConverterFn` in bufferization options now takes in just a
type instead of a full value
llvm/llvm-project#144658

**Related GitHub Issues:** 
[sc-95176]
[sc-95664]

---------

Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants