Skip to content

Commit a896fe3

Browse files
Migrate gradient dialect to new one-shot bufferization (#1740)
**Context:** This work is based on #1027 . As part of the mlir update, the bufferization of the custom catalyst dialects need to migrate to the new one-shot bufferization interface, as opposed to the old pattern-rewrite style bufferization passes. See more context in #1027. The `Quantum` dialect was migrated in #1686 . The `Catalyst` dialect was migrated in #1708 . Note that #1139 refactors the gradient dialect's bufferization into preprocess, bufferization, and postprocess. Only the middle bufferization stage is supposed to be replaced by one-shot bufferization. **Description of the Change:** Migrate `Gradient` dialect to one-shot bufferization. **Benefits:** Align with mlir practices; one step closer to updating mlir. [sc-71487] --------- Co-authored-by: Tzung-Han Juang <tzunghan.juang@gmail.com>
1 parent d81d46e commit a896fe3

23 files changed

+935
-603
lines changed

doc/releases/changelog-dev.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@
212212
[(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027)
213213
[(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
214214
[(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)
215+
[(#1740)](https://github.com/PennyLaneAI/catalyst/pull/1740)
215216

216217
* Redundant `OptionalAttr` is removed from `adjoint` argument in `QuantumOps.td` TableGen file
217218
[(#1746)](https://github.com/PennyLaneAI/catalyst/pull/1746)

frontend/catalyst/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
219219
"one-shot-bufferize{dialect-filter=memref}",
220220
"inline",
221221
"gradient-preprocess",
222-
"gradient-bufferize",
222+
"one-shot-bufferize{dialect-filter=gradient unknown-type-conversion=identity-layout-map}",
223223
"scf-bufferize",
224224
"convert-tensor-to-linalg", # tensor.pad
225225
"convert-elementwise-to-linalg", # Must be run before --arith-bufferize

mlir/include/Gradient/IR/GradientOps.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
#include "mlir/IR/OpImplementation.h"
2222
#include "mlir/IR/SymbolTable.h"
2323
#include "mlir/Interfaces/CallInterfaces.h"
24-
25-
#include "Gradient/IR/GradientInterfaces.h"
24+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2625

2726
#include "Gradient/IR/GradientDialect.h"
27+
#include "Gradient/IR/GradientInterfaces.h"
2828

2929
#define GET_OP_CLASSES
3030
#include "Gradient/IR/GradientOps.h.inc"

mlir/include/Gradient/IR/GradientOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
include "mlir/Interfaces/FunctionInterfaces.td"
1919
include "mlir/Interfaces/CallInterfaces.td"
20+
include "mlir/Interfaces/ControlFlowInterfaces.td"
2021
include "mlir/IR/SymbolInterfaces.td"
2122
include "mlir/IR/BuiltinAttributes.td"
2223
include "mlir/IR/OpBase.td"
@@ -306,7 +307,7 @@ def ForwardOp : Gradient_Op<"forward",
306307
OptionalAttr<DictArrayAttr>: $res_attrs
307308
);
308309

309-
let regions = (region AnyRegion: $body);
310+
let regions = (region MaxSizedRegion<1>: $body);
310311

311312
let builders = [OpBuilder<(ins
312313
"mlir::StringRef":$name, "mlir::FunctionType":$type,
@@ -362,7 +363,7 @@ def ReverseOp : Gradient_Op<"reverse",
362363
OptionalAttr<DictArrayAttr>: $res_attrs
363364
);
364365

365-
let regions = (region AnyRegion: $body);
366+
let regions = (region MaxSizedRegion<1>: $body);
366367

367368
let builders = [OpBuilder<(ins
368369
"mlir::StringRef":$name, "mlir::FunctionType":$type,
@@ -388,7 +389,7 @@ def ReverseOp : Gradient_Op<"reverse",
388389
}
389390

390391
def ReturnOp : Gradient_Op<"return",
391-
[Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> {
392+
[ReturnLike, Terminator, ParentOneOf<["ForwardOp", "ReverseOp"]>]> {
392393

393394
let summary = "Return tapes or nothing";
394395

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright 2024-2025 Xanadu Quantum Technologies Inc.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
using namespace mlir;
18+
19+
namespace catalyst {
20+
21+
namespace gradient {
22+
23+
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);
24+
25+
}
26+
27+
} // namespace catalyst

mlir/include/Gradient/Transforms/Passes.td

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,6 @@
1717

1818
include "mlir/Pass/PassBase.td"
1919

20-
def GradientBufferizationPass : Pass<"gradient-bufferize"> {
21-
let summary = "Bufferize tensors in quantum operations.";
22-
23-
let dependentDialects = [
24-
"bufferization::BufferizationDialect",
25-
"memref::MemRefDialect",
26-
"index::IndexDialect"
27-
];
28-
29-
let constructor = "catalyst::createGradientBufferizationPass()";
30-
}
31-
3220
def GradientPreprocessingPass : Pass<"gradient-preprocess"> {
3321
let summary = "Insert Func.CallOp for ForwardOp and ReverseOp.";
3422

mlir/include/Gradient/Transforms/Patterns.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
namespace catalyst {
2222
namespace gradient {
2323

24-
void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
2524
void populatePreprocessingPatterns(mlir::RewritePatternSet &);
2625
void populatePostprocessingPatterns(mlir::RewritePatternSet &);
2726
void populateLoweringPatterns(mlir::RewritePatternSet &);

mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h"
2323

2424
using namespace mlir;
25-
using namespace mlir::bufferization;
2625
using namespace catalyst;
2726

2827
/**

mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ void catalyst::registerAllCatalystPasses()
3737
mlir::registerPass(catalyst::createDisentangleSWAPPass);
3838
mlir::registerPass(catalyst::createEmitCatalystPyInterfacePass);
3939
mlir::registerPass(catalyst::createGEPInboundsPass);
40-
mlir::registerPass(catalyst::createGradientBufferizationPass);
4140
mlir::registerPass(catalyst::createGradientConversionPass);
4241
mlir::registerPass(catalyst::createGradientPreprocessingPass);
4342
mlir::registerPass(catalyst::createGradientPostprocessingPass);

mlir/lib/Driver/CompilerDriver.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
#include "Driver/Support.h"
6565
#include "Gradient/IR/GradientDialect.h"
6666
#include "Gradient/IR/GradientInterfaces.h"
67+
#include "Gradient/Transforms/BufferizableOpInterfaceImpl.h"
6768
#include "Gradient/Transforms/Passes.h"
6869
#include "Ion/IR/IonDialect.h"
6970
#include "MBQC/IR/MBQCDialect.h"
@@ -966,6 +967,7 @@ int QuantumDriverMainFromCL(int argc, char **argv)
966967

967968
// Register bufferization interfaces
968969
catalyst::registerBufferizableOpInterfaceExternalModels(registry);
970+
catalyst::gradient::registerBufferizableOpInterfaceExternalModels(registry);
969971
catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);
970972

971973
// Register and parse command line options.

0 commit comments

Comments
 (0)