Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds OUTPUT_PADDING to ConvTrans2D #890

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Adds OUTPUT_PADDING to ConvTrans2D
- Draft state.
- Unsure if correct, but a very simple and quick test gives the same
  result as PyTorch.
- Note: The TensorFlow result differs from both the dfdx and the PyTorch results.

Reference PyTorch test:
```python
# Reference check of ConvTranspose2d `output_padding` behaviour in PyTorch.
#
# Builds two transposed-convolution layers that differ only in
# `output_padding` (0 vs. 1), loads the same fixed 3x3 kernel into both,
# and runs the same 2x2 input through each so the outputs can be compared
# against another implementation.
import numpy as np  # required: the fixtures below are built with np.array
import torch

# Input of shape (batch=1, channels=1, height=2, width=2).
x = np.array([[[[0.1, 0.7], [0.3, 0.4]]]])
# Kernel of shape (in_channels=1, out_channels=1, 3, 3).
w = np.array([[[[-0.1, -0.3, 0.7], [0.8, -0.2, 0.1], [0.3, 0.4, -0.5]]]])

# Identical layers except for `output_padding`.
a = torch.nn.ConvTranspose2d(output_padding=0, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False)
b = torch.nn.ConvTranspose2d(output_padding=1, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False)

# The layers hold float32 parameters, so convert the float64 numpy fixtures.
x = torch.from_numpy(x).float()
w0 = torch.from_numpy(w).float()

# Replace the randomly initialised weights with the fixed kernel.
with torch.no_grad():
    a.weight = torch.nn.Parameter(w0)
    b.weight = torch.nn.Parameter(w0)

ya = a(x)
yb = b(x)

print(ya.size()) # torch.Size([1, 1, 3, 3])
print(yb.size()) # torch.Size([1, 1, 4, 4])

print(ya)
print(yb)
```
  • Loading branch information
swfsql committed Mar 1, 2024
commit e81228c300a8a48c4e257bdaeb71c46fcc8b18be
91 changes: 61 additions & 30 deletions dfdx-core/src/tensor_ops/convtrans2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub(super) trait ConvTrans2DKernel<E: Dtype>: Storage<E> {
) -> Result<(), Error>;
}

pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>: Sized {
type Convolved;

/// Applies a 2D transposed convolution to the input tensor.
Expand All @@ -61,8 +61,9 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Self::Convolved {
self.try_convtrans2d(stride, padding, dilation, groups)
self.try_convtrans2d(stride, padding, dilation, groups, output_padding)
.unwrap()
}

Expand All @@ -73,6 +74,7 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error>;
}

Expand All @@ -82,27 +84,31 @@ impl<
const PADDING: usize,
const DILATION: usize,
Groups: Dim,
const OUTPUT_PADDING: usize,
const DIM: usize,
> TryConvTrans2D<Const<STRIDE>, Const<PADDING>, Const<DILATION>, Groups>
> TryConvTrans2D<Const<STRIDE>, Const<PADDING>, Const<DILATION>, Groups, Const<OUTPUT_PADDING>>
for (Const<DIM>, Const<KERNEL>)
where
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>: Sized,
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>:
Sized,
{
type Convolved = Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>;
type Convolved =
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>;

fn try_convtrans2d(
self,
_: Const<STRIDE>,
_: Const<PADDING>,
_: Const<DILATION>,
_: Groups,
_: Const<OUTPUT_PADDING>,
) -> Result<Self::Convolved, Error> {
Ok(Const)
}
}

impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
TryConvTrans2D<Stride, Padding, Dilation, Groups> for (usize, Kernel)
impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim, OutputPadding: Dim>
TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding> for (usize, Kernel)
{
type Convolved = usize;

Expand All @@ -112,18 +118,33 @@ impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
padding: Padding,
dilation: Dilation,
_: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error> {
let (dim, kernel) = self;
Ok(
((dim - 1) * stride.size() + dilation.size() * (kernel.size() - 1) + 1)
.checked_sub(2 * padding.size())
.unwrap(),
)
Ok(((dim - 1) * stride.size()
+ dilation.size() * (kernel.size() - 1)
+ 1
+ output_padding.size())
.checked_sub(2 * padding.size())
.unwrap())
}
}

impl<InpChan, OutChanOverGroups, Kernel, Stride, Padding, Dilation, Groups, H, W, E, D, T>
TryConvTrans2D<Stride, Padding, Dilation, Groups>
impl<
InpChan,
OutChanOverGroups,
Kernel,
Stride,
Padding,
Dilation,
Groups,
OutputPadding,
H,
W,
E,
D,
T,
> TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>
for (
Tensor<(InpChan, H, W), E, D, T>,
Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>,
Expand All @@ -136,23 +157,26 @@ where
Padding: Dim,
Dilation: Dim,
Groups: Dim,
OutputPadding: Dim,
H: Dim,
W: Dim,
E: Dtype,
D: ConvTrans2DKernel<E> + crate::tensor_ops::reshape_to::ReshapeKernel<E>,
T: Tape<E, D>,
OutChanOverGroups: std::ops::Mul<Groups>,
<OutChanOverGroups as std::ops::Mul<Groups>>::Output: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
{
type Convolved = Tensor<
(
<OutChanOverGroups as std::ops::Mul<Groups>>::Output,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
),
E,
D,
Expand All @@ -165,11 +189,13 @@ where
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error> {
let (img, filters) = self;
let (inp_chan, h, w) = img.shape;
let img = img.try_reshape_like(&(Const::<1>, inp_chan, h, w))?;
let out = (img, filters).try_convtrans2d(stride, padding, dilation, groups)?;
let out =
(img, filters).try_convtrans2d(stride, padding, dilation, groups, output_padding)?;
let (_, out_chan, out_h, out_w) = out.shape;
out.try_reshape_like(&(out_chan, out_h, out_w))
}
Expand All @@ -182,13 +208,14 @@ impl<
Padding,
Dilation,
Groups,
OutputPadding,
Batch,
H,
W,
E,
D,
T,
> TryConvTrans2D<Stride, Padding, Dilation, Groups>
> TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>
for (
Tensor<(Batch, InpChan, H, W), E, D, T>,
Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>,
Expand All @@ -201,6 +228,7 @@ where
Padding: Dim,
Dilation: Dim,
Groups: Dim,
OutputPadding: Dim,
Batch: Dim,
H: Dim,
W: Dim,
Expand All @@ -209,17 +237,19 @@ where
T: Tape<E, D>,
OutChanOverGroups: std::ops::Mul<Groups>,
<OutChanOverGroups as std::ops::Mul<Groups>>::Output: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
{
type Convolved = Tensor<
(
Batch,
<OutChanOverGroups as std::ops::Mul<Groups>>::Output,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
),
E,
D,
Expand All @@ -232,6 +262,7 @@ where
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error> {
let (img, filters) = self;
assert_eq!(img.shape.1, filters.shape.0);
Expand All @@ -242,8 +273,8 @@ where
if img.strides != img.shape.strides() || filters.strides != filters.shape.strides() {
panic!("Image & filter inputs to conv2d must be contiguous");
}
let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups);
let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups);
let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups, output_padding);
let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups, output_padding);
let op = ConvTrans2DOp {
stride: stride.size(),
padding: padding.size(),
Expand Down
28 changes: 14 additions & 14 deletions dfdx-core/src/tensor_ops/convtrans2d/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ fn test_convtrans2d_default() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down Expand Up @@ -125,8 +125,8 @@ fn test_convtrans2d_stride_2() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down Expand Up @@ -223,8 +223,8 @@ fn test_convtrans2d_padded() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>, Const::<0>);
assert_close_to_literal!(
y,
[
Expand Down Expand Up @@ -283,8 +283,8 @@ fn test_convtrans2d_batched() {
let x: Tensor<Rank3<3, 28, 28>, TestDtype, _> = dev.sample_normal();
let w: Tensor<Rank4<3, 5, 6, 6>, TestDtype, _> = dev.sample_normal();

let y: Tensor<Rank3<5, 83, 83>, _, _, _> =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
let y: Tensor<Rank3<5, 83, 83>, _, _, _> = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>);
let y0 = y.retaped::<NoneTape>();
let grads0 = y.square().mean().backward();
let x0 = grads0.get(&x);
Expand All @@ -294,8 +294,8 @@ fn test_convtrans2d_batched() {
.broadcast::<Rank4<10, 3, 28, 28>, _>()
.reshape::<Rank4<10, 3, 28, 28>>();

let y: Tensor<Rank4<10, 5, 83, 83>, _, _, _> =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
let y: Tensor<Rank4<10, 5, 83, 83>, _, _, _> = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>);
for i in 0..10 {
assert_close_to_tensor!(y0, y.retaped::<NoneTape>().select(dev.tensor(i)), 1e-5);
}
Expand Down Expand Up @@ -341,8 +341,8 @@ fn test_convtrans2d_grouped() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down Expand Up @@ -451,8 +451,8 @@ fn test_convtrans2d_dilated() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down
Loading
Loading