Commit bad0f0f
authored
Implement ConvXDTranspose (#853)
This PR implements unified transpose convolution covering 1D/2D/3D,
SAME/VALID/CAUSAL and arbitrary padding, arbitrary window, stride, and
dilation.
SAME and VALID is equivalent to jax.lax.conv_transpose(). CAUSAL is defined in
this PR.
Each Literal padding follows the formulas below,
* SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2)))
pad_total = window+stride-2
when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1))
pad_total = window+stride-2 + max(window-stride, 0)
when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
pad_total = window+stride-2
Note: output_size = input_size*stride - (window+stride-2) + pad_total
= input_size*stride <- "SAME", "CAUSAL"
= input_size*stride + max(window-stride, 0) <- "VALID"
Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1.
dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window()
The following illustration demonstrates how Conv Transpose operates, assuming all kernel values
are set to 1 for simplicity in showcasing output values.
In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
pad| |pad
paddings: 0|0 0 1 1|0
0 0 0 -> 0
0 0 1 -> 1
0 1 1 -> 2
1 1 0 -> 2
* "VALID" padding=(2, 2)
pad | |pad
paddings: 0 0|0 0 1 1|0 0
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 1 -> 2
1 1 0 -> 2
1 0 0 -> 1
* "CAUSAL" padding=(2, 0)
pad | |pad
paddings: 0 0|0 0 1 1|
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 1 -> 2
In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
pad | |pad
paddings: 0 0|0 * 0 * 1 * 1|0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 0 -> 1
1 0 1 -> 2
0 1 0 -> 1
* "VALID" padding=(2, 2)
pad | |pad
paddings: 0 0|0 * 0 * 1 * 1|0 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 0 -> 1
1 0 1 -> 2
0 1 0 -> 1
1 0 0 -> 1
* "CAUSAL" padding=(2, 1)
pad | |pad
paddings: 0 0|0 * 0 * 1 * 1|0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 0 -> 1
1 0 1 -> 2
0 1 0 -> 1
In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
pad | |pad
paddings: 0 0|0 * * 0 * * 1 * * 1|0 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 0 -> 1
1 0 0 -> 1
0 0 1 -> 1
0 1 0 -> 1
1 0 0 -> 1
In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
pad | |pad
paddings: 0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 0 -> 1
1 0 0 -> 1
0 0 0 -> 0
0 0 1 -> 1
0 1 0 -> 1
1 0 0 -> 1
0 0 0 -> 0
Here is how to compute output_size, given the above example,
1. |_| -(window-1)
2. |_______________________| (input_size-1)*stride + 1
3. |_| |___| + pad_total
So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
= input_size*stride - (window+stride-2) + pad_total
= input_size*stride <- "SAME", "CAUSAL"
= input_size*stride + max(window-stride, 0) <- "VALID"
OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1.
For example, when window=3 and dilation=2, dilate_window=5.
In the stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
pad | |pad
paddings: 0 0 0|0 * 0 * 1 * 1|0 0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 1 -> 1
0 * 0 * 0 -> 0
0 * 1 * 1 -> 2
0 * 0 * 0 -> 0
1 * 1 * 0 -> 2
* "VALID" padding=(4, 4)
pad | |pad
paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 1 -> 1
0 * 0 * 0 -> 0
0 * 1 * 1 -> 2
0 * 0 * 0 -> 0
1 * 1 * 0 -> 2
0 * 0 * 0 -> 0
1 * 0 * 0 -> 1
* "CAUSAL" padding=(4, 1)
pad | |pad
paddings: 0 0 0 0|0 * 0 * 1 * 1|0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 0 -> 0
0 * 0 * 1 -> 1
0 * 0 * 0 -> 0
0 * 1 * 1 -> 2
0 * 0 * 0 -> 01 parent b352259 commit bad0f0f
File tree
6 files changed
+1683
-206
lines changed- axlearn
- common
- experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer
- vision
6 files changed
+1683
-206
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
35 | | - | |
| 35 | + | |
36 | 36 | | |
37 | 37 | | |
38 | 38 | | |
| |||
72 | 72 | | |
73 | 73 | | |
74 | 74 | | |
75 | | - | |
| 75 | + | |
76 | 76 | | |
77 | 77 | | |
78 | 78 | | |
| |||
96 | 96 | | |
97 | 97 | | |
98 | 98 | | |
99 | | - | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
100 | 108 | | |
101 | 109 | | |
102 | 110 | | |
| |||
0 commit comments