diff --git a/tools/pnnx/src/pass_level2/F_batch_norm.cpp b/tools/pnnx/src/pass_level2/F_batch_norm.cpp index e922e458a06..fb73a497c5f 100644 --- a/tools/pnnx/src/pass_level2/F_batch_norm.cpp +++ b/tools/pnnx/src/pass_level2/F_batch_norm.cpp @@ -45,4 +45,38 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm, 10) +class F_batch_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 10 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 running_mean +pnnx.Input input_2 0 1 running_var +pnnx.Input input_3 0 1 weight +pnnx.Input input_4 0 1 bias +prim::Constant op_0 0 1 momentum value=* +prim::Constant op_1 0 1 eps value=%eps +aten::_native_batch_norm_legit_no_training op_2 7 3 input weight bias running_mean running_var momentum eps out save_mean save_invstd +pnnx.Output output 3 0 out save_mean save_invstd +)PNNXIR"; + } + + const char* type_str() const + { + return "F.batch_norm"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(op, captured_params, captured_attrs); + + op->outputs.resize(1); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_1, 10) + } // namespace pnnx