Skip to content

Commit

Permalink
pnnx convert batchnorm function (Tencent#5097)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Oct 23, 2023
1 parent 54ab805 commit 0c2a4a2
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tools/pnnx/src/pass_level2/F_batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,38 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm, 10)

class F_batch_norm_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
9 10
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 running_mean
pnnx.Input input_2 0 1 running_var
pnnx.Input input_3 0 1 weight
pnnx.Input input_4 0 1 bias
prim::Constant op_0 0 1 momentum value=*
prim::Constant op_1 0 1 eps value=%eps
aten::_native_batch_norm_legit_no_training op_2 7 3 input weight bias running_mean running_var momentum eps out save_mean save_invstd
pnnx.Output output 3 0 out save_mean save_invstd
)PNNXIR";
}

const char* type_str() const
{
return "F.batch_norm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(op, captured_params, captured_attrs);

op->outputs.resize(1);
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_1, 10)

} // namespace pnnx

0 comments on commit 0c2a4a2

Please sign in to comment.