TransposeConv wrong shape? #262

Open
BmanClark opened this issue Oct 18, 2023 · 15 comments
Labels
awaiting response · dependencies · question (Further information is requested) · work/medium (work that can be done within 1 day)

Comments

@BmanClark

The TFLite code doesn't like the shape of the Transpose Convolution node I have in my TinyNN-converted network. It seems consistent across the Transpose Convolutions: TFLite expects the bias to match the size of the first dimension of the weights, but in my TinyNN-converted network it is in fact the size of the last dimension. (The error is on the first TransposeConv, but there are a number of similarly-formed ones in the network which will presumably hit the same error.)

The network runs fine before conversion, and the tflite file looks very sensible in netron, but the TFLite runtime doesn't like it...

It gives an error of:

Unexpected failure when preparing tensor allocations: tensorflow/lite/kernels/transpose_conv.cc:289 NumElements(bias) != SizeOfDimension(weights, 0) (3 != 1)
Node number 194 (TRANSPOSE_CONV) failed to prepare.

The node in question has (according to netron):

input: 1x4x4x3
output shape: <4>
weights: <1x2x2x3>
bias: <3>
output: 1x8x8x3

which evidently doesn't match TFLite's rules. Is this a conversion error, or how do I cope with it?

Incidentally, the biases in all the Transpose Convolutions are entirely 0s and the weights are a suspiciously simple arrangement of 1s and 0s. Since the bias is optional for this TFLite operator, is there a way to not include it?
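(For reference, the check in the error message is TFLite's rule that the bias length must equal the first dimension of the OHWI weight tensor, i.e. the number of output channels, so for weights <1x2x2x3> it wants a 1-element bias rather than 3. The prepare-time failure can be reproduced with a minimal interpreter script; this is just a sketch and the model path is a placeholder.)

```python
import tensorflow as tf

# "converted_model.tflite" is a placeholder for the TinyNN output discussed above.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")

# The shape check in transpose_conv.cc runs while tensors are allocated, so the
# "Node number ... (TRANSPOSE_CONV) failed to prepare." error surfaces here as a RuntimeError.
interpreter.allocate_tensors()
```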

@peterjc123 peterjc123 added the bug (Something isn't working) label on Oct 19, 2023
@peterjc123
Collaborator

@BmanClark Could you please tell me which version of TensorFlow (Lite) you are testing against?

@peterjc123
Collaborator

It is indeed a conversion error. Could you please upload a TorchScript model and the shape of the inputs so that we can reproduce this issue?

@peterjc123 peterjc123 added the work/medium (work that can be done within 1 day) label on Oct 19, 2023
@BmanClark
Author

TFLite is 2.8.0. The line number is off by a few from the latest, but the code seems pretty much the same in transpose_conv.cc.
I'll work on getting a TorchScript version of the model now; currently it's PyTorch code plus a loaded state dict.

The error will need fixing, but as a thought: if the bias is all zeros and optional, wouldn't it be more efficient not to include it in the TFLite model?

@BmanClark
Author

Hmm, I've created the TorchScript version, but it's (just!) too big to attach, so I'm still working out how I can share it with you.
The input shape is 1x512x512x4 (float32).
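(For anyone following along, a minimal sketch of exporting a TorchScript file with torch.jit.trace; the module below is a stand-in for the real network, and the PyTorch-side input is assumed to be the NCHW view of the reported 1x512x512x4 NHWC shape.)

```python
import torch
import torch.nn as nn

# Stand-in module; the real network is the one being converted with TinyNN.
class DemoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = DemoModel().eval()
example_input = torch.rand(1, 4, 512, 512)  # NCHW; 1x512x512x4 is the NHWC shape after conversion

traced = torch.jit.trace(model, example_input)
torch.jit.save(traced, "model_traced.pt")  # this .pt file is what gets shared
```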

@peterjc123
Collaborator

> TFLite is 2.8.0. The line number is off by a few from the latest, but the code seems pretty much the same in transpose_conv.cc. I'll work on getting a TorchScript version of the model now; currently it's PyTorch code plus a loaded state dict.
>
> The error will need fixing, but as a thought: if the bias is all zeros and optional, wouldn't it be more efficient not to include it in the TFLite model?

As for the workaround, yeah, it should be easy. But I'm more interested in how it happened.

> Hmm, I've created the TorchScript version, but it's (just!) too big to attach, so I'm still working out how I can share it with you. The input shape is 1x512x512x4 (float32).

Google Drive or OneDrive, either will do if size is an issue. Or you could try whether the model can be traced using our code tracer (usage can be found in examples/tracer) so that you don't need to upload the weights.

@BmanClark
Author

I've put it on my Google Drive and shared it with you.

@peterjc123
Collaborator

> I've put it on my Google Drive and shared it with you.

Thanks. I will take a look tomorrow.

@BmanClark
Author

Tinkering further, I can get around the problem by specifying group_conv_rewrite=True. This splits each of the Transpose Convolutions into 512 (?) parallel ones and concatenates them again. It runs, but it doesn't look as efficient as it could be... It gives netron a headache too 😄
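(A rough sketch of where that flag goes, assuming the TFLiteConverter entry point from the TinyNN README and that group_conv_rewrite is a converter constructor argument; the module, shapes and output path are illustrative only.)

```python
import torch
import torch.nn as nn
from tinynn.converter import TFLiteConverter

# Minimal stand-in containing a grouped transpose convolution, the op that triggers the issue.
class Demo(nn.Module):
    def __init__(self, channels=4):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2, groups=channels)

    def forward(self, x):
        return self.deconv(x)

model = Demo().eval()
dummy_input = torch.rand(1, 4, 512, 512)

converter = TFLiteConverter(
    model,
    dummy_input,
    tflite_path="demo_group_deconv.tflite",
    group_conv_rewrite=True,  # split each grouped (de)conv into per-group ops + CONCATENATION
)
converter.convert()
```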

@peterjc123
Collaborator

peterjc123 commented Oct 20, 2023

@BmanClark Hi, I've put up a fix to eliminate the zero bias tensors for the DeConv ops in #263. But I'm not sure whether group deconv is supported in TFLite. If it isn't, then this fix may be useless.
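(Not the actual patch in #263, just the general idea it describes: a bias tensor that is entirely zeros contributes nothing, and TRANSPOSE_CONV's bias input is optional, so it can be dropped during conversion. The helper below is hypothetical.)

```python
import numpy as np

def is_droppable_bias(bias) -> bool:
    """True if the bias tensor exists but is all zeros, in which case the
    optional bias input of TRANSPOSE_CONV can simply be omitted."""
    return bias is not None and not np.any(np.asarray(bias))

# Example: the all-zero biases observed in the converted network would be dropped.
print(is_droppable_bias(np.zeros(3, dtype=np.float32)))   # True
print(is_droppable_bias(np.array([0.0, 0.1, 0.0])))       # False
```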

@peterjc123
Collaborator

Update: it looks like group deconvolution is not supported, at least in the XNNPACK delegate. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc#L6213

> Tinkering further, I can get around the problem by specifying group_conv_rewrite=True. This splits each of the Transpose Convolutions into 512 (?) parallel ones and concatenates them again. It runs, but it doesn't look as efficient as it could be... It gives netron a headache too 😄

So this may be the only way to use grouped deconvolutions in TFLite. You may create a new issue in the TFLite repo.

@BmanClark
Author

I'm not actually looking to target XNNPACK ultimately (although I would like it as a reference), but I've created an issue: tensorflow/tensorflow#62181
Thanks for the bias fix; I'll check whether that's enough for ArmNN to run, or whether I need to get them to fix it too...

@peterjc123
Collaborator

> I'm not actually looking to target XNNPACK ultimately (although I would like it as a reference), but I've created an issue: tensorflow/tensorflow#62181. Thanks for the bias fix; I'll check whether that's enough for ArmNN to run, or whether I need to get them to fix it too...

Actually, I didn't find the changes to enable group deconvolution in the general optimized kernel, either. So it is possible that it isn't supported by the TFLite interpreter at all. As for ArmNN, the story may be different: they could support that case, since for them TFLite is only a format for model representation.

@BmanClark
Author

ArmNN doesn't currently support group deconvolution either. They might look at adding it for me, but I might not need them to, as the network creator has kindly changed the group deconvolution to an Upsample for me! I love the OSS community!
Anyway, I'm away, so I won't get to try it until next week, but I can give you an update once I've had a chance to get further. Thanks for your help.

@peterjc123 peterjc123 added the awaiting response and question (Further information is requested) labels and removed the bug (Something isn't working) label on Oct 26, 2023
@BmanClark
Author

Update: the network happily converts now that its Group Deconvolutions (aka Transpose Convolutions) are replaced with Upsamples. Thank you. I have had a request for more information on the TensorFlow issue I raised (specifically on how I used TinyNN to convert), so I'm providing that.
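(For the record, the equivalence that makes this replacement safe: a per-channel transpose convolution with a 2x2 kernel of ones and stride 2 just copies each input pixel into a 2x2 block, which is exactly nearest-neighbour upsampling. A small self-contained check, with the layer shapes chosen purely for illustration:)

```python
import torch
import torch.nn as nn

channels = 3

# Grouped transpose conv with all-ones 2x2 kernels and stride 2, consistent with the
# simple 1s-and-0s weights described earlier in the thread.
deconv = nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2, groups=channels, bias=False)
nn.init.ones_(deconv.weight)

# The drop-in replacement: nearest-neighbour upsampling, which TFLite supports directly.
upsample = nn.Upsample(scale_factor=2, mode="nearest")

x = torch.rand(1, channels, 4, 4)
with torch.no_grad():
    assert torch.allclose(deconv(x), upsample(x))  # identical 1x3x8x8 outputs
```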

@peterjc123
Collaborator

@BmanClark I have commented on that issue. Glad you solved it the other way.
