Merge branch 'master' into ci-android-riscv
nihui authored Aug 30, 2024
2 parents ddb7090 + 5e2d56d commit 454e647
Showing 3 changed files with 143 additions and 2 deletions.
26 changes: 26 additions & 0 deletions tools/pnnx/src/pass_level2/F_hardswish.cpp
@@ -343,4 +343,30 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_2, 9)

class F_hardswish_onnx_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
8 7
pnnx.Input input 0 1 input
prim::Constant op_0 0 1 v3 value=3
aten::add op_1 2 1 input v3 a
aten::clamp op_2 1 1 a b max=6 min=0
aten::mul op_3 2 1 input b c
prim::Constant op_4 0 1 v6 value=6
aten::div op_5 2 1 c v6 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.hardswish";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_3, 9)

} // namespace pnnx
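For reference, the F_hardswish_onnx_3 pattern above matches the decomposed expression x * clamp(x + 3, 0, 6) / 6, which is hardswish written out by hand, so the whole add/clamp/mul/div chain collapses into a single F.hardswish op. A minimal standalone sketch of that identity (illustration only, not part of this commit):

#include <algorithm>
#include <cstdio>
#include <initializer_list>

// add -> clamp(0, 6) -> mul -> div, as in the matched subgraph
static float hardswish_decomposed(float x)
{
    return x * std::clamp(x + 3.f, 0.f, 6.f) / 6.f;
}

int main()
{
    // spot-check a few values of the piecewise-linear activation
    for (float x : {-4.f, -1.f, 0.f, 2.f, 5.f})
        std::printf("hardswish(%+.1f) = %+.4f\n", x, hardswish_decomposed(x));
    return 0;
}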
66 changes: 64 additions & 2 deletions tools/pnnx/src/pass_level2/F_linear.cpp
@@ -129,7 +129,7 @@ class F_linear_onnx : public GraphRewriterPass
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 weight
pnnx.Input input_2 0 1 bias
Gemm op_0 3 1 input weight bias out alpha=1.000000e+00 beta=1.000000e+00 transB=1
Gemm gemm 3 1 input weight bias out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -138,6 +138,39 @@ pnnx.Output output 1 0 out
{
return "F.linear";
}

bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
if (captured_params.find("gemm.alpha") != captured_params.end())
{
if (captured_params.at("gemm.alpha").type != 3 || captured_params.at("gemm.alpha").f != 1.f)
return false;
}

if (captured_params.find("gemm.beta") != captured_params.end())
{
if (captured_params.at("gemm.beta").type != 3 || captured_params.at("gemm.beta").f != 1.f)
return false;
}

if (captured_params.find("gemm.transA") != captured_params.end())
{
if (captured_params.at("gemm.transA").type != 2 || captured_params.at("gemm.transA").i != 0)
return false;
}

if (captured_params.find("gemm.transB") == captured_params.end())
return false;

if (captured_params.at("gemm.transB").type != 2 || captured_params.at("gemm.transB").i != 1)
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
{
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_onnx, 10)
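For context, the pattern now captures the Gemm node's attributes generically with %*=%*, and the added match() accepts the node only when it is a plain linear layer: alpha and beta equal to 1 (type 3 is pnnx's float parameter kind), transA equal to 0 and transB equal to 1 (type 2 is the int kind), i.e. Y = X * W^T + C with no extra scaling. A tiny standalone sketch of that transB=1 computation (illustration only, not from this commit):

#include <cstdio>

int main()
{
    // x: (1, in_features=2), w: (out_features=2, in_features=2), b: (out_features=2)
    const float x[2] = {1.f, 2.f};
    const float w[2][2] = {{3.f, 4.f}, {5.f, 6.f}};
    const float b[2] = {0.5f, -0.5f};

    // y[n] = b[n] + sum_k x[k] * w[n][k]   -- Gemm with transB=1 and unit alpha/beta
    for (int n = 0; n < 2; n++)
    {
        float acc = b[n];
        for (int k = 0; k < 2; k++)
            acc += x[k] * w[n][k];
        std::printf("y[%d] = %.1f\n", n, acc); // 11.5, 16.5
    }
    return 0;
}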
@@ -152,7 +185,7 @@ class F_linear_onnx_1 : public GraphRewriterPass
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 bias
pnnx.Attribute weight 0 1 weight @data=(%in_features,%out_features)f32
Gemm gemm 3 1 input weight bias out alpha=1.000000e+00 beta=1.000000e+00
Gemm gemm 3 1 input weight bias out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -169,6 +202,35 @@ pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
if (captured_params.find("gemm.alpha") != captured_params.end())
{
if (captured_params.at("gemm.alpha").type != 3 || captured_params.at("gemm.alpha").f != 1.f)
return false;
}

if (captured_params.find("gemm.beta") != captured_params.end())
{
if (captured_params.at("gemm.beta").type != 3 || captured_params.at("gemm.beta").f != 1.f)
return false;
}

if (captured_params.find("gemm.transA") != captured_params.end())
{
if (captured_params.at("gemm.transA").type != 2 || captured_params.at("gemm.transA").i != 0)
return false;
}

if (captured_params.find("gemm.transB") != captured_params.end())
{
if (captured_params.at("gemm.transB").type != 2 || captured_params.at("gemm.transB").i != 0)
return false;
}

return true;
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const int in_features = captured_params.at("in_features").i;
53 changes: 53 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
@@ -702,6 +702,57 @@ pnnx.Output output 1 0 out
}
};

class fuse_multiheadattention_pass_1_1_1 : public fuse_multiheadattention_pass_sameqkv
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
19 18
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input 256 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input 257 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input 260 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight
Tensor.view op_3 1 1 256 263 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_4 1 1 257 258 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_5 1 1 260 261 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 263 264 dims=(0,2,1,3)
Tensor.permute op_7 1 1 258 259 dims=(0,2,1,3)
Tensor.permute op_8 1 1 261 262 dims=(0,2,1,3)
torch.transpose op_9 1 1 259 265 dim0=-1 dim1=-2
torch.matmul op_10 2 1 264 265 266
pnnx.Expression op_11 1 1 266 267 expr=div(@0,%sqrt_feat_per_head)
F.softmax softmax 1 1 267 268 dim=%softmax_dim
torch.matmul op_13 2 1 268 262 269
Tensor.permute op_14 1 1 269 270 dims=(0,2,1,3)
Tensor.reshape op_15 1 1 270 271 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 271 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
const int embed_dim = captured_params.at("embed_dim").i;
const int num_heads = captured_params.at("num_heads").i;
const int feat_per_head = captured_params.at("feat_per_head").i;
const float sqrt_feat_per_head = captured_params.at("sqrt_feat_per_head").f;
const int softmax_dim = captured_params.at("softmax_dim").i;

if (embed_dim != num_heads * feat_per_head)
return false;

if (!NearlyEqual(sqrt_feat_per_head, sqrt(feat_per_head), 0.001))
return false;

int softmax_input_rank = (int)matched_operators.at("softmax")->inputs[0]->shape.size();
if (softmax_dim != -1 && softmax_dim != softmax_input_rank - 1)
return false;

return true;
}
};
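For context, this new pass recognizes a hand-written attention block: three nn.Linear projections of the same input, view/permute into (batch, num_heads, seq, feat_per_head), q times k^T divided by sqrt_feat_per_head, softmax over the last dimension, a second matmul with v, and a final output projection. Its match() additionally requires embed_dim == num_heads * feat_per_head, a captured scale close to sqrt(feat_per_head), and a softmax over the last axis. A minimal single-head sketch of the computation the matched subgraph performs (illustration only, not part of this commit):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

typedef std::vector<std::vector<float> > Mat; // (rows, cols)

// scores = q * k^T / sqrt(d), softmax over the last dim, out = scores * v
static Mat attention_one_head(const Mat& q, const Mat& k, const Mat& v)
{
    const size_t seq = q.size();
    const size_t d = q[0].size();
    const float inv_scale = 1.f / std::sqrt((float)d); // mirrors div(@0, sqrt_feat_per_head)

    Mat out(seq, std::vector<float>(v[0].size(), 0.f));
    for (size_t i = 0; i < seq; i++)
    {
        // raw scores for query i against every key, kept numerically stable for softmax
        std::vector<float> s(seq);
        float smax = -1e30f;
        for (size_t j = 0; j < seq; j++)
        {
            float acc = 0.f;
            for (size_t t = 0; t < d; t++)
                acc += q[i][t] * k[j][t];
            s[j] = acc * inv_scale;
            smax = std::max(smax, s[j]);
        }
        float sum = 0.f;
        for (size_t j = 0; j < seq; j++)
        {
            s[j] = std::exp(s[j] - smax);
            sum += s[j];
        }
        // weighted sum of the value rows
        for (size_t j = 0; j < seq; j++)
            for (size_t t = 0; t < v[0].size(); t++)
                out[i][t] += s[j] / sum * v[j][t];
    }
    return out;
}

int main()
{
    Mat q = {{1.f, 0.f}, {0.f, 1.f}};
    Mat k = {{1.f, 0.f}, {0.f, 1.f}};
    Mat v = {{1.f, 2.f}, {3.f, 4.f}};
    Mat out = attention_one_head(q, k, v);
    std::printf("%.3f %.3f\n%.3f %.3f\n", out[0][0], out[0][1], out[1][0], out[1][1]);
    return 0;
}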

class fuse_multiheadattention_pass_1_2 : public fuse_multiheadattention_pass_qkv
{
public:
@@ -2082,6 +2133,7 @@ void fuse_multiheadattention(Graph& graph)
fuse_multiheadattention_pass_q_samekv d;
fuse_multiheadattention_pass_1 b1;
fuse_multiheadattention_pass_1_1 b11;
fuse_multiheadattention_pass_1_1_1 b111;
fuse_multiheadattention_pass_1_2 b12;
fuse_multiheadattention_pass_2 c1;
fuse_multiheadattention_pass_3 d1;
@@ -2122,6 +2174,7 @@ void fuse_multiheadattention(Graph& graph)
pnnx_graph_rewrite(graph, &d, opindex);
pnnx_graph_rewrite(graph, &b1, opindex);
pnnx_graph_rewrite(graph, &b11, opindex);
pnnx_graph_rewrite(graph, &b111, opindex);
pnnx_graph_rewrite(graph, &b12, opindex);
pnnx_graph_rewrite(graph, &c1, opindex);
pnnx_graph_rewrite(graph, &d1, opindex);
