
OneShotChannelPruner results in the miss of some operators #272

Open
Xz-Alan opened this issue Dec 21, 2023 · 4 comments
Labels: bug (Something isn't working)

Comments

Xz-Alan commented Dec 21, 2023

Hello author, I have found that the graph tracing algorithm may miss some operators during pruning when the network contains multiple parallel computation branches.

import torch
import torch.nn as nn

from tinynn.prune.oneshot_pruner import OneShotChannelPruner


class SimpleNet(nn.Module):
    def __init__(self, embed_dim=32, num_blocks=3):
        super().__init__()

        self.head = nn.Conv2d(3, embed_dim, 3, 1, 1)

        self.attn1 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.attn2 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.propagate_fusion = nn.Conv2d(3*embed_dim, embed_dim, 3, 1, 1)

        self.conv1 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(2*embed_dim, 2*embed_dim, 3, 1, 1)
        self.conv3 = nn.Conv2d(2*embed_dim, 2*embed_dim, 3, 1, 1)

        self.down1 = nn.Conv2d(embed_dim, 2*embed_dim, 4, 2, 1)

        self.up1 = nn.ConvTranspose2d(2*embed_dim, embed_dim, 2, 2)

        self.tail = nn.Conv2d(2*embed_dim, embed_dim, 3, 1, 1)

    # Iteratively fuse each frame's features with the reference frame,
    # reusing self.propagate_fusion in every iteration.
    def my_fusion(self, feats):
        feat_base = feats[0]
        fusion_feat = torch.zeros_like(feat_base)
        for i in range(len(feats)):
            feat_current = feats[i]
            feat = [feat_base] + [feat_current] + [fusion_feat]
            feat = torch.cat(feat, dim=1)
            add_feat = self.propagate_fusion(feat)
            fusion_feat = fusion_feat + add_feat
        return fusion_feat

    def forward(self, inp):
        deep_feat = self.head(inp)

        t, c, h, w = deep_feat.shape

        ref_feat = deep_feat[0, :, :, :].unsqueeze(0)
        embedding_ref = self.attn1(ref_feat)
        embedding = self.attn2(deep_feat)

        # Channel-wise correlation of each frame's embedding with the reference
        corr_diff = []
        corr_l = []
        for i in range(t):
            emb_neighbor = embedding[i, :, :, :].unsqueeze(0)
            corr = torch.sum(emb_neighbor * embedding_ref, 1)
            corr_l.append(corr)
            if i == 0:
                continue
            else:
                corr_difference = torch.abs(corr_l[i] - corr_l[0])
                corr_diff.append(corr_difference)

        corr_prob = torch.sigmoid(torch.cat(corr_diff, dim=0))
        corr_prob = corr_prob.unsqueeze(1).expand(t-1, c, h, w)
        oth_feat = deep_feat[1:, :, :, :] * corr_prob

        feat_guided = [deep_feat[0, :, :, :].unsqueeze(0)]
        feat_guided += [oth_feat[i, :, :, :].unsqueeze(0) for i in range(0, t-1)]
        feat = self.my_fusion(feat_guided)

        feat1 = self.conv1(feat)
        down1 = self.down1(feat1)
        feat2 = self.conv2(down1)

        up1 = self.up1(feat2)
        cat1 = torch.cat([up1, feat1], 1)
        feat3 = self.conv3(cat1)

        feat_out = self.tail(feat3) + feat
        return feat_out


if __name__ == '__main__':
    inp = torch.randn(8, 3, 128, 128)
    model = SimpleNet(embed_dim=32)

    out = model(inp)
    print(f"input: {inp.shape}")
    print(f"output: {out.shape}")

    pruner = OneShotChannelPruner(model, inp, config={'sparsity': 1/8, 'metrics': 'l2_norm'})

    st_flops = pruner.calc_flops()
    pruner.prune()
    pruner.graph.generate_code(output_script_path='./test.py',
                               output_weight_path='./test.pth',
                               model_name="SimpleNet_prune")
    ed_flops = pruner.calc_flops()
    print(f"Pruning over, reduced FLOPS {100 * (st_flops - ed_flops) / st_flops:.2f}%  ({st_flops} -> {ed_flops})")

Running the above code gives the error below.

INFO (tinynn.graph.modifier) Start tracking tensor dimension changes...
INFO (tinynn.graph.modifier) Start dividing subgraphs according to tensor dependencies...
INFO (tinynn.graph.modifier) Start to eliminate dimension change conflicts...
INFO (tinynn.graph.modifier) Start generating new subgraphs without conflicts...
INFO (tinynn.prune.oneshot_pruner) Register a mask for each operator
INFO (tinynn.prune.oneshot_pruner) subgraph [head] compute over
INFO (tinynn.prune.oneshot_pruner) subgraph [attn2] compute over
INFO (tinynn.prune.oneshot_pruner) subgraph [propagate_fusion_7] compute over
INFO (tinynn.prune.oneshot_pruner) subgraph [conv1] compute over
INFO (tinynn.prune.oneshot_pruner) subgraph [conv3] compute over
INFO (tinynn.prune.oneshot_pruner) subgraph [tail] compute over
INFO (tinynn.prune.oneshot_pruner) Apply the mask of each operator
INFO (tinynn.graph.modifier) [CONV] head: output 32 -> 28
INFO (tinynn.graph.modifier) [CONV] head: bias 32 -> 28
INFO (tinynn.graph.modifier) [CONV] attn1: input 32 -> 28
INFO (tinynn.graph.modifier) [CONV] attn2: input 32 -> 28
INFO (tinynn.graph.modifier) [CONV] attn2: output 32 -> 28
INFO (tinynn.graph.modifier) [CONV] attn2: bias 32 -> 28
INFO (tinynn.graph.modifier) [CONV] propagate_fusion: input 96 -> 84
INFO (tinynn.graph.modifier) [CONV] propagate_fusion: output 32 -> 24
INFO (tinynn.graph.modifier) [CONV] propagate_fusion: bias 32 -> 24
INFO (tinynn.graph.modifier) [CONV] conv3: output 64 -> 56
INFO (tinynn.graph.modifier) [CONV] conv3: bias 64 -> 56
INFO (tinynn.graph.modifier) [CONV] tail: input 64 -> 56
INFO (tinynn.graph.modifier) [CONV] tail: output 32 -> 28
INFO (tinynn.graph.modifier) [CONV] tail: bias 32 -> 28
Traceback (most recent call last):
  File "d:\doc\code\tmp\tinynn\graph\tracer.py", line 3335, in trace
    new_graph.init()
  File "d:\doc\code\tmp\tinynn\graph\tracer.py", line 2033, in init
    self.module(*actual_input)
  File "D:\app\anaconda3\envs\alanosu\lib\site-packages\torch\nn\modules\module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "d:\doc\code\tmp\tmp.py", line 52, in forward
    corr = torch.sum(emb_neighbor * embedding_ref, 1)
  File "d:\doc\code\tmp\tinynn\graph\tracer.py", line 1045, in new_func
    result = orig_func(*args, **kwargs)
RuntimeError: The size of tensor a (28) must match the size of tensor b (32) at non-singleton dimension 1
ERROR (tinynn.graph.tracer) inputs: ['input_0_f']
ERROR (tinynn.graph.tracer) forwards: ['head', 'shape_0_f', 'getitem_0_f', 'unsqueeze_0_f', 'attn1', 'attn2', 'getitem_1_f', 'unsqueeze_1_f']
ERROR (tinynn.graph.tracer) outputs: []
ERROR (tinynn.graph.tracer) constants: []

You can see that attn1, conv1, and conv2 are all missing from the computed subgraphs. In the apply-mask log, attn1's output channels are never pruned (it still outputs 32 channels) while attn2's output is cut to 28, so the element-wise multiply inside torch.sum(emb_neighbor * embedding_ref, 1) fails with the size mismatch above.
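
A minimal sketch (not part of the original report; SumRepro and its layer names are hypothetical) that isolates the suspected pattern, an element-wise multiply between two parallel conv branches followed by torch.sum over the channel dimension, to check whether this alone reproduces the missing operators:

import torch
import torch.nn as nn

from tinynn.prune.oneshot_pruner import OneShotChannelPruner


class SumRepro(nn.Module):
    # Hypothetical model: two parallel conv branches whose outputs are
    # multiplied element-wise and then reduced with torch.sum over channels.
    def __init__(self, dim=32):
        super().__init__()
        self.conv_a = nn.Conv2d(3, dim, 3, 1, 1)
        self.conv_b = nn.Conv2d(3, dim, 3, 1, 1)
        self.tail = nn.Conv2d(dim, dim, 3, 1, 1)

    def forward(self, x):
        a = self.conv_a(x)
        b = self.conv_b(x)
        # Same reduction pattern as the correlation step in the report
        corr = torch.sum(a * b, 1, keepdim=True)
        return self.tail(a * torch.sigmoid(corr))


if __name__ == '__main__':
    model = SumRepro(dim=32)
    inp = torch.randn(1, 3, 32, 32)
    pruner = OneShotChannelPruner(model, inp, config={'sparsity': 0.25, 'metrics': 'l2_norm'})
    pruner.prune()

If conv_a and conv_b end up with different pruned channel counts here as well, the problem is in how torch.sum's operands are tied together during tracking rather than in the larger network structure.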

zk1998 (Collaborator) commented Dec 21, 2023 via email

peterjc123 added the bug label Dec 21, 2023
peterjc123 (Collaborator) commented:

I think it is probably because torch.sum is unsupported, but I'm not sure. Would you please upload the model so that we can look into it?

Xz-Alan (Author) commented Dec 22, 2023

Thank you for your reply, but TFLite seems to support the sum operator.
The model I uploaded is obtained with the above code:
model.zip

Looking forward to your reply.

peterjc123 (Collaborator) commented:

> Thank you for your reply, but TFLite seems to support the sum operator. The model I uploaded is obtained with the above code. model.zip
> Looking forward to your reply.

I mean the pruner module doesn't support that op.
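
One possible way to sidestep the shape mismatch while the op is unsupported (an untested sketch, not suggested in the thread): produce both operands of the multiply with a single shared conv, so that any channel pruning applies to them consistently. SharedAttn below is a hypothetical reduction of SimpleNet; it only avoids the specific mismatch from the traceback, and whether the pruner then tracks torch.sum correctly is a separate question.

import torch
import torch.nn as nn

from tinynn.prune.oneshot_pruner import OneShotChannelPruner


class SharedAttn(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.head = nn.Conv2d(3, dim, 3, 1, 1)
        # One conv instead of attn1/attn2: both embeddings share its weights,
        # so their channel counts stay equal after pruning.
        self.attn = nn.Conv2d(dim, dim, 3, 1, 1)
        self.tail = nn.Conv2d(dim, dim, 3, 1, 1)

    def forward(self, x):
        feat = self.head(x)
        ref = self.attn(feat[0].unsqueeze(0))
        emb = self.attn(feat)
        corr = torch.sum(emb * ref, 1, keepdim=True)  # operands always match
        return self.tail(feat * torch.sigmoid(corr))


if __name__ == '__main__':
    model = SharedAttn(dim=32)
    inp = torch.randn(4, 3, 32, 32)
    pruner = OneShotChannelPruner(model, inp, config={'sparsity': 0.25, 'metrics': 'l2_norm'})
    pruner.prune()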
