
Conversation

YuchiWen

No description provided.

@YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 44658d1 to 6e5ac70 on November 22, 2022 04:01
@daquexian (Member)

Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion tests: https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024
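
For reference, a minimal test along those lines might look like the sketch below. It assumes the ConvTranspose case is exposed through the existing fuse_bn_into_conv pass (the pass name and the toy shapes here are assumptions, not something confirmed in this thread):

import numpy as np
import onnx
import onnxoptimizer
from onnx import TensorProto, helper

def _init(name, shape):
    # random float32 initializer for the toy model
    data = np.random.rand(*shape).astype(np.float32)
    return helper.make_tensor(name, TensorProto.FLOAT, shape, data.tobytes(), raw=True)

# ConvTranspose -> BatchNormalization, all parameters as initializers
nodes = [
    helper.make_node("ConvTranspose", ["X", "W"], ["Z"], kernel_shape=[2, 2]),
    helper.make_node("BatchNormalization", ["Z", "scale", "bias", "mean", "var"], ["Y"]),
]
graph = helper.make_graph(
    nodes, "convtranspose_bn",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 9, 9])],
    [_init("W", [3, 4, 2, 2]), _init("scale", [4]), _init("bias", [4]),
     _init("mean", [4]), _init("var", [4])],
)
model = helper.make_model(graph)
optimized = onnxoptimizer.optimize(model, ["fuse_bn_into_conv"])
# after fusion the BN node should be gone
assert all(n.op_type != "BatchNormalization" for n in optimized.graph.node)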

@YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 6e5ac70 to 5d4d388 on March 6, 2023 10:52
@YuchiWen (Author) commented Mar 6, 2023

> Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion tests: https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024

@daquexian Done, please review.

@YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 5d4d388 to ff1229e on March 6, 2023 11:31
@huangzhicong3

Hello, I tried to use this commit to fuse the BN and ConvTranspose layers in my model and found a bug.
The error message is:
passes/fuse_bn_into_conv.h:71: modify_conv: Assertion conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C failed.

According to the ONNX documentation (https://onnx.ai/onnx/operators/onnx__ConvTranspose.html), the weight of ConvTranspose has shape (Cin, Cout, K, K), unlike a normal Conv layer's (Cout, Cin, K, K), so the assertion on the first dimension fails.
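
A minimal illustration of the axis difference (the variable names here are made up for the example): BN folding multiplies the weight by a per-output-channel scale, and for ConvTranspose that scale must broadcast along axis 1 rather than axis 0.

import numpy as np

Cin, Cout, K = 3, 4, 2
conv_w = np.random.rand(Cin, Cout, K, K).astype(np.float32)  # ConvTranspose layout
bn_w = np.random.rand(Cout).astype(np.float32)
bn_var = np.random.rand(Cout).astype(np.float32)
bn_eps = 1e-5

s = bn_w / np.sqrt(bn_var + bn_eps)        # per-output-channel BN scale
fused_w = conv_w * s[None, :, None, None]  # ConvTranspose: scale along axis 1
# for a regular Conv weight (Cout, Cin, K, K) it would be s[:, None, None, None]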

@huangzhicong3 commented May 30, 2024

Hi, I would like to share my code for fusing ConvTranspose and BatchNormalization. It has been tested on my own model; I hope it helps others who hit the same issue.
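
For reference, the standard BN-folding identity the script implements, with BN scale γ, bias β, running mean μ, running variance σ², and epsilon ε:

new_W = W * γ / sqrt(σ² + ε)   (applied per output channel, i.e. along axis 1 of the ConvTranspose weight)
new_b = γ * (b − μ) / sqrt(σ² + ε) + β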

import numpy as np
import onnx
import sclblonnx as so

model = onnx.load('../onnx/models/backbone_clean.onnx')

all_initializer = model.graph.initializer
all_node = model.graph.node
ConvTranspose_list = []
BatchNormalization_list = []
for i, node in enumerate(all_node):
    # search convtranspose and batchnormalization
    if node.op_type == "ConvTranspose":
        # print(i, node.name, node.op_type,  node.input, node.output)
        ConvTranspose_list.append(node)
    if node.op_type == "BatchNormalization":
        # print(i, node.name, node.op_type,  node.input, node.output)
        BatchNormalization_list.append(node)

valid_ConvTranspose_list = []
for node in ConvTranspose_list:
    output = node.output
    for bn_node in BatchNormalization_list:
        bn_inputs = bn_node.input
        if output[0] in bn_inputs:
            valid_ConvTranspose_list.append({"conv": node, "bn": bn_node})
            break  # pair found; each ConvTranspose is matched to at most one BN

# print(valid_ConvTranspose_list)
param_dict = {}
for node in valid_ConvTranspose_list:
    conv = node["conv"]
    bn = node["bn"]
    # find params
    param_name = list(conv.input) + list(bn.input)
    for i, initializer in enumerate(all_initializer):
        if initializer.name in param_name:
            param_dict[initializer.name] = onnx.numpy_helper.to_array(initializer)
# print(param_dict)
for node in valid_ConvTranspose_list:
    conv = node["conv"]
    bn = node["bn"]

    # look up BN attributes by name rather than by position (attribute order
    # is not guaranteed); momentum is irrelevant for inference-time folding
    bn_attrs = {a.name: a.f for a in bn.attribute}
    bn_eps = bn_attrs.get("epsilon", 1e-5)

    bn_w = param_dict[bn.input[1]]  # [Cout, ]
    bn_b = param_dict[bn.input[2]]  # [Cout, ]
    bn_mean = param_dict[bn.input[3]]  # [Cout, ]
    bn_var = param_dict[bn.input[4]]  # [Cout, ]

    conv_w = param_dict[conv.input[1]]  # [Cin, Cout, K, K]; the transpose below assumes group == 1
    if len(conv.input) > 2:
        conv_b = param_dict[conv.input[2]]
    else:
        conv_b = np.zeros_like(bn_b)  # [Cout, ]
    conv_w_tran = conv_w.transpose(1, 0, 2, 3)

    Cout = conv_w_tran.shape[0]
    conv_w_reshape = conv_w_tran.reshape([Cout, -1])
    w_bn = np.diag(bn_w / (np.sqrt(bn_eps + bn_var)))
    new_conv_w = np.matmul(w_bn, conv_w_reshape).reshape(conv_w_tran.shape).transpose(1, 0, 2, 3)
    bn_b_tmp = bn_b - (np.multiply(bn_w, bn_mean) / (np.sqrt(bn_eps + bn_var)))
    # the bias gets the same per-channel scale as the weight: bn_w / sqrt(var + eps),
    # applied via the diagonal matrix w_bn
    new_conv_b = np.matmul(w_bn, conv_b) + bn_b_tmp

    # build the fused node, copying the ConvTranspose attributes wholesale;
    # indexing attributes by position is fragile because optional attributes
    # (e.g. output_padding) and serialization order can shift the indices
    new_node = onnx.helper.make_node(
        name=conv.name+'_bn',
        op_type="ConvTranspose",
        inputs=[conv.input[0], conv.name+'_bn.weights', conv.name+'_bn.bias'],
        outputs=[bn.output[0]],
    )
    new_node.attribute.extend(conv.attribute)
    initializer_w = onnx.helper.make_tensor(
        name=conv.name+'_bn.weights',
        data_type=onnx.TensorProto.FLOAT,
        dims=new_conv_w.shape,
        vals=new_conv_w.tobytes(),
        raw=True
    )
    initializer_b = onnx.helper.make_tensor(
        name=conv.name+'_bn.bias',
        data_type=onnx.TensorProto.FLOAT,
        dims=new_conv_b.shape,
        vals=new_conv_b.tobytes(),
        raw=True
    )

    model.graph.initializer.append(initializer_w)
    model.graph.initializer.append(initializer_b)

    # insert the fused node at the position of the original ConvTranspose
    for i, n in enumerate(all_node):
        if conv.name == n.name:
            model.graph.node.insert(i, new_node)
            break
    # clean node
    model.graph.node.remove(conv)
    model.graph.node.remove(bn)

onnx.checker.check_model(model)
onnx.save(model, '../onnx/models/backbone_fuse.onnx')

graph = so.graph_from_file('../onnx/models/backbone_fuse.onnx')
graph = so.clean(graph)
so.check(graph)
so.graph_to_file(graph, '../onnx/models/backbone_fuse.onnx')
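
As a sanity check (not part of the script above), one can compare the original and fused models on a random input with onnxruntime; BN folding should be exact up to float rounding. The paths follow the ones used above:

import numpy as np
import onnxruntime as ort

sess_a = ort.InferenceSession('../onnx/models/backbone_clean.onnx')
sess_b = ort.InferenceSession('../onnx/models/backbone_fuse.onnx')
inp = sess_a.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in inp.shape]  # fix dynamic dims to 1
x = np.random.rand(*shape).astype(np.float32)

out_a = sess_a.run(None, {inp.name: x})[0]
out_b = sess_b.run(None, {inp.name: x})[0]
assert np.allclose(out_a, out_b, atol=1e-4)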
