Conv2D IN_CHAN is wrong if GROUPS > 1 #792

Closed
Awpteamoose opened this issue May 25, 2023 · 1 comment · Fixed by #797

Comments

@Awpteamoose

Awpteamoose commented May 25, 2023

Consider this:

```rust
// IN_CHAN, OUT_CHAN, KERNEL_SIZE, STRIDE, PADDING, DILATION, GROUPS
Conv2D<16, 32, 2, 2, 0, 1, 16>
```

In other words, this is a depthwise convolution with a 2×2 kernel and stride 2 over a 16×W×H volume, which I'd expect to produce a 32×W/2×H/2 volume.
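The W/2×H/2 expectation follows from the usual convolution output-size formula, the one PyTorch documents for `Conv2d`. The sketch below just evaluates it for the parameters in this issue; `conv_out_size` is a hypothetical helper, not part of dfdx. Note that `GROUPS` does not appear in the formula: grouping changes how channels are wired, not the spatial size.

```rust
/// Hypothetical helper: output length along one spatial axis of a 2D convolution,
/// out = (in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1
const fn conv_out_size(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // Parameters from the issue: KERNEL_SIZE = 2, STRIDE = 2, PADDING = 0, DILATION = 1.
    let h = conv_out_size(32, 2, 2, 0, 1);
    let w = conv_out_size(32, 2, 2, 0, 1);
    // 32 output channels, spatial dims halved.
    println!("output: 32x{}x{}", h, w);
}
```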

One would expect this to compile:

```rust
let module = dev.build_module::<Conv2D<16, 32, 2, 2, 0, 1, 16>, f32>();
let in_ten: Tensor<Rank3<16, 32, 32>, _, _> = dev.ones();
let out_ten: Tensor<Rank3<32, 16, 16>, _, _> = Module::forward(&module, in_ten);
```

But we get the error:

```text
   |
81 | ...d(&module, in_ten);
   |               ^^^^^^ expected `256`, found `16`
   |
   = note: expected constant `256`
              found constant `16`
```
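The constants in this error suggest how `Conv2D` counted channels at the time: it appears to treat `IN_CHAN` as the per-group channel count and expect `IN_CHAN * GROUPS` input channels (16 × 16 = 256 here, and 1 × 16 = 16 for the `Conv2D<1, 32, …, 16>` workaround). This is an inference from the error message, not from the dfdx source. A minimal sketch of the two conventions, with hypothetical helper names:

```rust
/// Hypothetical: input channels expected if IN_CHAN means "channels seen by each group"
/// (the behaviour the error message implies).
fn expected_input_per_group_convention(in_chan: usize, groups: usize) -> usize {
    in_chan * groups
}

/// Hypothetical: input channels expected if IN_CHAN means "total input channels"
/// (the PyTorch convention the issue asks for).
fn expected_input_total_convention(in_chan: usize, groups: usize) -> usize {
    assert!(groups > 0 && in_chan % groups == 0, "IN_CHAN must be divisible by GROUPS");
    in_chan
}

fn main() {
    // Conv2D<16, 32, 2, 2, 0, 1, 16> from the issue.
    println!("{}", expected_input_per_group_convention(16, 16)); // 256, as in the error
    println!("{}", expected_input_total_convention(16, 16)); // 16, what the caller passed
    // Workaround: Conv2D<1, 32, 2, 2, 0, 1, 16>.
    println!("{}", expected_input_per_group_convention(1, 16)); // 16
}
```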

Changing the code like this compiles and gives the correct result (checked against TensorFlow's tf.nn.depthwise_conv2d):

```rust
let module = dev.build_module::<Conv2D<1, 32, 2, 2, 0, 1, 16>, f32>();
let in_ten: Tensor<Rank3<16, 32, 32>, _, _> = dev.ones();
let out_ten: Tensor<Rank3<32, 16, 16>, _, _> = Module::forward(&module, in_ten);
```

But it looks like a type-level bug: my number of input channels isn't affected by the number of groups. I understand that each filter only convolves over IN_CHAN/GROUPS channels, so the "fix" makes sense, but that seems like an implementation detail that shouldn't leak into the interface.
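The distinction between "the layer's input has IN_CHAN channels" and "each filter sees IN_CHAN / GROUPS of them" can be made concrete with a naive grouped convolution. This is a plain-Rust sketch, not dfdx code: the function name, the flat CHW layout, and the PyTorch-style weight layout `[out_chan][in_chan / groups][k][k]` are assumptions for illustration, and it omits padding and dilation.

```rust
/// Naive grouped 2D convolution (hypothetical sketch, not dfdx's implementation).
/// Input: flat CHW buffer with `in_chan` channels. Weight: [out_chan][in_chan/groups][k][k].
/// No padding, dilation 1, square kernel.
fn grouped_conv2d(
    input: &[f32],
    (in_chan, h, w): (usize, usize, usize),
    weight: &[f32],
    out_chan: usize,
    k: usize,
    stride: usize,
    groups: usize,
) -> (Vec<f32>, usize, usize) {
    assert!(in_chan % groups == 0 && out_chan % groups == 0);
    let in_per_group = in_chan / groups; // channels each filter actually sees
    let out_per_group = out_chan / groups;
    let (oh, ow) = ((h - k) / stride + 1, (w - k) / stride + 1);
    assert_eq!(input.len(), in_chan * h * w); // the layer still takes all IN_CHAN channels
    assert_eq!(weight.len(), out_chan * in_per_group * k * k);

    let mut out = vec![0.0f32; out_chan * oh * ow];
    for oc in 0..out_chan {
        let g = oc / out_per_group; // group this output channel belongs to
        for y in 0..oh {
            for x in 0..ow {
                let mut acc = 0.0;
                for icg in 0..in_per_group {
                    let ic = g * in_per_group + icg; // absolute input channel
                    for ky in 0..k {
                        for kx in 0..k {
                            let iv = input[(ic * h + y * stride + ky) * w + x * stride + kx];
                            let wv = weight[((oc * in_per_group + icg) * k + ky) * k + kx];
                            acc += iv * wv;
                        }
                    }
                }
                out[(oc * oh + y) * ow + x] = acc;
            }
        }
    }
    (out, oh, ow)
}

fn main() {
    // The issue's case: 16 input channels, 32 outputs, 2x2 kernel, stride 2, 16 groups.
    let (in_chan, h, w, out_chan, k, groups) = (16, 32, 32, 32, 2, 16);
    let input = vec![1.0f32; in_chan * h * w];
    let weight = vec![1.0f32; out_chan * (in_chan / groups) * k * k];
    let (out, oh, ow) = grouped_conv2d(&input, (in_chan, h, w), &weight, out_chan, k, 2, groups);
    // With all-ones input and weights, each output sums one channel's 2x2 window: 4.0.
    println!("{}x{}x{}, first = {}", out_chan, oh, ow, out[0]);
}
```

The input buffer is sized by the full 16 channels while the weight only stores one input channel per filter, which is the split the issue argues the type-level interface should reflect.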

@coreylowman
Owner

Ahh yes, this is because the in_chan dimension of the weight needs to change based on groups. You can see this in PyTorch:

```python
>>> torch.nn.Conv2d(16, 32, 3, groups=1).weight.shape
torch.Size([32, 16, 3, 3])
>>> torch.nn.Conv2d(16, 32, 3, groups=2).weight.shape
torch.Size([32, 8, 3, 3])
>>> torch.nn.Conv2d(16, 32, 3, groups=4).weight.shape
torch.Size([32, 4, 3, 3])
```

Currently Conv2D doesn't do this.
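The PyTorch output above follows a simple rule: the weight shape is `[OUT_CHAN, IN_CHAN / GROUPS, KERNEL, KERNEL]`, while the layer's input keeps all `IN_CHAN` channels. Here is a small runtime sketch of that rule; `grouped_conv_weight_shape` is a hypothetical helper, and it says nothing about how the fix in #797 expresses this at the type level.

```rust
/// Hypothetical helper: weight shape of a grouped 2D convolution under the PyTorch
/// convention, [OUT_CHAN, IN_CHAN / GROUPS, KERNEL, KERNEL].
fn grouped_conv_weight_shape(in_chan: usize, out_chan: usize, kernel: usize, groups: usize) -> [usize; 4] {
    // Both channel counts must split evenly across groups (groups > 0 checked first).
    assert!(
        groups > 0 && in_chan % groups == 0 && out_chan % groups == 0,
        "IN_CHAN and OUT_CHAN must be divisible by GROUPS"
    );
    [out_chan, in_chan / groups, kernel, kernel]
}

fn main() {
    // Mirrors the torch.nn.Conv2d(16, 32, 3, groups=g) examples.
    for groups in [1, 2, 4] {
        println!("groups={}: {:?}", groups, grouped_conv_weight_shape(16, 32, 3, groups));
    }
    // The issue's depthwise case: Conv2D<16, 32, 2, 2, 0, 1, 16>.
    println!("depthwise: {:?}", grouped_conv_weight_shape(16, 32, 2, 16));
}
```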
