-
Notifications
You must be signed in to change notification settings - Fork 608
[ONNX] Support onnx.LSTM #2969
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Support onnx.LSTM #2969
Conversation
|
Test case draft: lstm.onnx.mlir module {
func.func @lstm(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
}
}current Torch IR```mlir module { func.func @lstm(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none %int0 = torch.constant.int 0 %int0_0 = torch.constant.int 0 %0 = torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1,12,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,4],f32> %int0_1 = torch.constant.int 0 %int0_2 = torch.constant.int 0 %1 = torch.aten.select.int %arg2, %int0_1, %int0_2 : !torch.vtensor<[1,12,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,3],f32> %int0_3 = torch.constant.int 0 %int0_4 = torch.constant.int 0 %2 = torch.aten.select.int %arg3, %int0_3, %int0_4 : !torch.vtensor<[1,24],f32>, !torch.int, !torch.int -> !torch.vtensor<[24],f32> %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 %none_5 = torch.constant.none %int0_6 = torch.constant.int 0 %int1_7 = torch.constant.int 1 %3 = torch.prim.ListConstruct %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list %int6 = torch.constant.int 6 %4 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2,3],f32> %5 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2,3],f32> %int0_8 = torch.constant.int 0 %int0_9 = torch.constant.int 0 %6 = torch.aten.select.int %4, %int0_8, %int0_9 : !torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> %int0_10 = torch.constant.int 0 %int0_11 = torch.constant.int 0 %7 = torch.aten.select.int %5, %int0_10, %int0_11 : !torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> %int12 = torch.constant.int 12 %int24 = torch.constant.int 24 %8 = torch.aten.slice.Tensor %2, %int0_6, %int0_6, %int12, %int1_7 : !torch.vtensor<[24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[12],f32> %9 = torch.aten.slice.Tensor %2, %int0_6, %int12, %int24, %int1_7 : !torch.vtensor<[24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[12],f32> %10 = torch.prim.ListConstruct : () -> !torch.list> %int15 = torch.constant.int 15 %true = torch.constant.bool true %int0_12 = torch.constant.int 0 %int1_13 = torch.constant.int 1 %11:2 = torch.prim.Loop %int15, %true, init(%6, %7) { ^bb0(%arg4: !torch.int, %arg5: !torch.vtensor<[2,3],f32>, %arg6: !torch.vtensor<[2,3],f32>): %16 = torch.aten.select.int %arg0, %int0_12, %arg4 : !torch.vtensor<[15,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4],f32> %int0_14 = torch.constant.int 0 %int1_15 = torch.constant.int 1 %int4 = torch.constant.int 4 %17 = torch.prim.ListConstruct %int1_15, %int4 : (!torch.int, !torch.int) -> !torch.list %18 = torch.aten.tile %16, %17 : !torch.vtensor<[2,4],f32>, !torch.list -> !torch.vtensor<[2,16],f32> %19 = torch.aten.tile %arg5, %17 : !torch.vtensor<[2,3],f32>, !torch.list -> !torch.vtensor<[2,12],f32> %20 = torch.aten.linear %18, %0, %8 : !torch.vtensor<[2,16],f32>, !torch.vtensor<[12,4],f32>, !torch.vtensor<[12],f32> -> !torch.vtensor<[2,12],f32> %21 = torch.aten.linear %19, %1, %9 : !torch.vtensor<[2,12],f32>, !torch.vtensor<[12,3],f32>, !torch.vtensor<[12],f32> -> !torch.vtensor<[2,12],f32> %22 = torch.aten.add.Tensor %20, %21, %int1_15 : !torch.vtensor<[2,12],f32>, !torch.vtensor<[2,12],f32>, !torch.int -> !torch.vtensor<[2,12],f32> %int3_16 = torch.constant.int 3 %int6_17 = torch.constant.int 6 %int9 = torch.constant.int 9 %int12_18 = torch.constant.int 12 %23 = torch.aten.slice.Tensor %22, %int1_15, %int0_14, %int9, %int1_15 : !torch.vtensor<[2,12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,9],f32> %24 = torch.aten.slice.Tensor %22, %int1_15, %int9, %int12_18, %int1_15 : !torch.vtensor<[2,12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> %25 = torch.aten.sigmoid %23 : !torch.vtensor<[2,9],f32> -> !torch.vtensor<[2,9],f32> %26 = torch.aten.tanh %24 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> %27 = torch.aten.slice.Tensor %25, %int1_15, %int0_14, %int3_16, %int1_15 : !torch.vtensor<[2,9],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> %28 = torch.aten.slice.Tensor %25, %int1_15, %int3_16, %int6_17, %int1_15 : !torch.vtensor<[2,9],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> %29 = torch.aten.slice.Tensor %25, %int1_15, %int6_17, %int9, %int1_15 : !torch.vtensor<[2,9],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> %30 = torch.aten.mul.Tensor %29, %arg6 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> %31 = torch.aten.mul.Tensor %27, %26 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> %32 = torch.aten.add.Tensor %30, %31, %int1_15 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32> %33 = torch.aten.tanh %32 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> %34 = torch.aten.mul.Tensor %28, %33 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> %35 = torch.aten.append.t %10, %34 : !torch.list>, !torch.vtensor<[2,3],f32> -> !torch.list> %36 = torch.aten.add.int %arg4, %int1_13 : !torch.int, !torch.int -> !torch.int torch.prim.Loop.condition %true, iter(%34, %32 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) %12 = torch.aten.unsqueeze %11#0, %int0_6 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[1,2,3],f32> %13 = torch.aten.unsqueeze %11#1, %int0_6 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[1,2,3],f32> %14 = torch.aten.stack %10, %int0_6 : !torch.list>, !torch.int -> !torch.vtensor<[15,2,3],f32> %15 = torch.aten.unsqueeze %14, %int1_7 : !torch.vtensor<[15,2,3],f32>, !torch.int -> !torch.vtensor<[15,1,2,3],f32> return %15, %12, %13 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32> } } |
|
Currently testing with: /home/azureuser/torch-mlir/build/bin/torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch))' lstm.onnx.mlir -o lstm.torch.mlir
~/torch-mlir/build/bin/torch-mlir-opt --mlir-print-debuginfo --mlir-elide-elementsattrs-if-larger=16 --mlir-print-stacktrace-on-diagnostic --mlir-disable-threading --mlir-print-ir-after-failure --mlir-print-ir-module-scope -pass-pipeline='builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)' ./lstm.torch.mlir -o ./lstm.linalg.mlir 2>&1 | tee torchtolinalg.log |
|
encountered a bunch of issues earlier, fixed thanks to Rob. Currently stuck on: ./lstm.torch.mlir:61:13: error: 'torch_c.to_builtin_tensor' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got 'tensor<2x3xf32>' %30 = torch.aten.mul.Tensor %29, %arg6 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> ^./lstm.torch.mlir:61:13: note: diagnostic emitted with trace: #0 0x000055b8b5e22e2d llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/azureuser/torch-mlir/externals/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:11 #1 0x000055b8b5bf959e emitDiag(mlir::Location, mlir::DiagnosticSeverity, llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Diagnostics.cpp:319:5 #2 0x000055b8b5bf94c5 mlir::emitError(mlir::Location, llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Diagnostics.cpp:330:10 #3 0x000055b8b5c9f888 mlir::Operation::emitError(llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Operation.cpp:269:29 #4 0x000055b8b5c9f359 mlir::Operation::emitOpError(llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Operation.cpp:672:22 #5 0x000055b8b2d9c635 mlir::torch::TorchConversion::__mlir_ods_local_type_constraint_TorchConversionOps1(mlir::Operation*, mlir::Type, llvm::StringRef, unsigned int) /home/azureuser/torch-mlir/build/tools/torch-mlir/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc:52:39 #6 0x000055b8b2da4a46 mlir::torch::TorchConversion::ToBuiltinTensorOp::verifyInvariantsImpl() /home/azureuser/torch-mlir/build/tools/torch-mlir/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc:1499:26 #7 0x000055b8b2d93cc2 mlir::OpTrait::OpInvariants::verifyTrait(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:432:35 #8 0x000055b8b2d93b05 std::enable_if>::value, mlir::LogicalResult>::type mlir::op_definition_impl::verifyTrait>(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:1620:10 #9 0x000055b8b2d938f7 mlir::LogicalResult mlir::op_definition_impl::verifyTraits, mlir::OpTrait::OneResult, mlir::OpTrait::OneTypedResult::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::OneOperand, mlir::OpTrait::OpInvariants, mlir::InferTypeOpInterface::Trait, mlir::ConditionallySpeculatable::Trait, mlir::OpTrait::AlwaysSpeculatableImplTrait, mlir::MemoryEffectOpInterface::Trait>(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:1631:29 #10 0x000055b8b2d937b5 mlir::Op::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::OneOperand, mlir::OpTrait::OpInvariants, mlir::InferTypeOpInterface::Trait, mlir::ConditionallySpeculatable::Trait, mlir::OpTrait::AlwaysSpeculatableImplTrait, mlir::MemoryEffectOpInterface::Trait>::verifyInvariants(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:2012:16 #11 0x000055b8b20126f5 mlir::LogicalResult llvm::detail::UniqueFunctionBase::CallImpl(void*, mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/FunctionExtras.h:221:12 #12 0x000055b8b2011e57 llvm::unique_function::operator()(mlir::Operation*) const /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/FunctionExtras.h:411:12 #13 0x000055b8b2d92466 mlir::RegisteredOperationName::Model::verifyInvariants(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OperationSupport.h:558:14 #14 0x000055b8b5ce8ba6 mlir::OperationName::verifyInvariants(mlir::Operation*) const /home/azureuser/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/OperationSupport.h:317:23 #15 0x000055b8b5ce5b8d (anonymous namespace)::OperationVerifier::verifyOnEntrance(mlir::Operation&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:179:48 #16 0x000055b8b5ce58e0 _ZZN12_GLOBAL__N_117OperationVerifier15verifyOperationERN4mlir9OperationEENK3$_2clIS2_EEDaPT_ /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:293:45 #17 0x000055b8b5ce47f7 _ZZN12_GLOBAL__N_117OperationVerifier15verifyOperationERN4mlir9OperationEENK3$_1clIZNS0_15verifyOperationES3_E3$_2EEDaOT_N4llvm12PointerUnionIJPS2_PNS1_5BlockEEEE /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:277:16 #18 0x000055b8b5ce401f (anonymous namespace)::OperationVerifier::verifyOperation(mlir::Operation&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:292:16 #19 0x000055b8b5ce3de1 (anonymous namespace)::OperationVerifier::verifyOpAndDominance(mlir::Operation&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:85:14 #20 0x000055b8b5ce3d92 mlir::verify(mlir::Operation*, bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:423:19 #21 0x000055b8b4826985 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:548:27 #22 0x000055b8b4826e74 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:585:16 #23 0x000055b8b4827cce mlir::detail::OpToOpPassAdaptor::runOnOperationImpl(bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:726:20 #24 0x000055b8b482757d mlir::detail::OpToOpPassAdaptor::runOnOperation(bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:709:1 #25 0x000055b8b482b186 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:517:11 #26 0x000055b8b482b135 void llvm::function_ref::callback_fn(long) /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:5 #27 0x000055b8b1efc0e9 llvm::function_ref::operator()() const /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:5 #28 0x000055b8b482ded5 void mlir::MLIRContext::executeAction(llvm::function_ref, llvm::ArrayRef, mlir::Pass&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/MLIRContext.h:276:3 #29 0x000055b8b48268f3 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:525:17 #30 0x000055b8b4826e74 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:585:16 #31 0x000055b8b48288b8 mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:896:10 #32 0x000055b8b48287e2 mlir::PassManager::run(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:876:60 #33 0x000055b8b1e7ee72 performActions(llvm::raw_ostream&, std::shared_ptr const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:396:17 #34 0x000055b8b1e7eaa8 processBuffer(llvm::raw_ostream&, std::unique_ptr>, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPool*) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:461:12 #35 0x000055b8b1e7e88c mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0::operator()(std::unique_ptr>, llvm::raw_ostream&) const /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:532:12 #36 0x000055b8b1e7e826 mlir::LogicalResult llvm::function_ref>, llvm::raw_ostream&)>::callback_fn>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::unique_ptr>, llvm::raw_ostream&) /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #37 0x000055b8b5d0ea22 llvm::function_ref>, llvm::raw_ostream&)>::operator()(std::unique_ptr>, llvm::raw_ostream&) const /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #38 0x000055b8b5d0e03d mlir::splitAndProcessBuffer(std::unique_ptr>, llvm::function_ref>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Support/ToolUtilities.cpp:28:12 #39 0x000055b8b1e7b75b mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:535:10 #40 0x000055b8b1e7b9f5 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:570:14 #41 0x000055b8b1e7bbc8 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:586:10 #42 0x000055b8b1e770e5 main /home/azureuser/torch-mlir/tools/torch-mlir-opt/torch-mlir-opt.cpp:43:33 #43 0x00007faabdc23a90 (/lib/x86_64-linux-gnu/libc.so.6+0x23a90) #44 0x00007faabdc23b49 __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x23b49) #45 0x000055b8b1e76f75 _start (/home/azureuser/torch-mlir/build/bin/torch-mlir-opt+0x234f75)./lstm.torch.mlir:61:13: note: see current operation: %45 = "torch_c.to_builtin_tensor"(%arg6) : (tensor<2x3xf32>) -> tensor<2x3xf32> loc("./lstm.torch.mlir":61:13) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will need a numerical validation test. Without that its impossible to say that the implementation is guaranteed to be correct. You will have to disable for torch.
|
Hi @renxida, is there anything remaining for this PR to be merged? |
|
@vivekkhandelwal1 yup. need a numeric test. but it's hitting some issues lowering past LinAlg. Specifically:
|
include/torch-mlir/Conversion/TorchOnnxToTorch/OnnxLstmExpander.h
Outdated
Show resolved
Hide resolved
|
@renxida, you have done the LLVM bump in this PR. Do we need that for the changes in this PR? If not, can we do it in a separate PR? Also, the CI seems to be broken might be because of the LLVM bump. |
no but i was keeping this up to date with main because github is telling me i have merge conflicts. is there a better way i should have done it? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some small nits and cleanup required but should be almost done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, please include a comment about testing downstream with a link to a PR that merges in that work.
No description provided.