Skip to content

Commit dd967eb

Browse files
renxidaXida Ren
andauthored
[ONNX] Support onnx.LSTM (#2969)
This PR only performs a lit test. In lieu of an e2e test, nod-ai/SHARK-TestSuite#142 makede sure that the lowering works & the numbers check out. Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
1 parent 1d6e4c3 commit dd967eb

File tree

9 files changed

+569
-2
lines changed

9 files changed

+569
-2
lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,25 @@ struct OpBinder {
191191
return failure();
192192
}
193193

194+
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
195+
StringRef nameSuffix) {
196+
SmallString<64> name("torch.onnx.");
197+
name.append(nameSuffix);
198+
auto attr = op->getAttr(name);
199+
if (!attr)
200+
return success();
201+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
202+
for (auto element : arrayAttr) {
203+
StringAttr stringAttr = element.dyn_cast<StringAttr>();
204+
if (!stringAttr)
205+
return failure();
206+
values.push_back(stringAttr.getValue().str());
207+
}
208+
return success();
209+
}
210+
return failure();
211+
}
212+
194213
ParseResult denseElementsAttr(ElementsAttr elementsattr,
195214
StringRef nameSuffix) {
196215
SmallString<64> name("torch.onnx.");

include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
1111
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
1212

13+
#include "mlir/IR/ImplicitLocOpBuilder.h"
1314
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
15+
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
16+
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
1417

1518
namespace mlir::torch::onnx_c {
1619

@@ -20,6 +23,9 @@ Value createConstantIntList(OpBinder binder,
2023

2124
Type getQTorchTypeFromTorchIntType(Type ty);
2225

26+
LogicalResult OnnxLstmExpander(OpBinder binder,
27+
ConversionPatternRewriter &rewriter);
28+
2329
bool areAllElementsDistinct(SmallVector<int64_t> array);
2430

2531
} // namespace mlir::torch::onnx_c

lib/Conversion/TorchOnnxToTorch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch
22
DefaultDomainAtoF.cpp
33
DefaultDomainGtoP.cpp
44
DefaultDomainQtoZ.cpp
5+
OnnxLstmExpander.cpp
56
Passes.cpp
67
Patterns.cpp
78
TorchOnnxToTorch.cpp

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
195195
binder.op, resultType, operand);
196196
return success();
197197
});
198+
patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander);
198199
patterns.onOp(
199200
"LogSoftmax", 13,
200201
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
538538
return success();
539539
});
540540
patterns.onOp(
541-
"Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
541+
"Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
542542
Torch::ValueTensorType resultType;
543543
Value data;
544544
Value axes;

0 commit comments

Comments
 (0)