
Commit bad0f0f

Implement ConvXDTranspose (#853)
This PR implements a unified transpose convolution covering 1D/2D/3D inputs, SAME/VALID/CAUSAL as well as arbitrary explicit padding, and arbitrary window, stride, and dilation. SAME and VALID are equivalent to jax.lax.conv_transpose(); CAUSAL is defined in this PR. Each literal padding follows the formulas below:

* SAME:   padding=(min(window-1, ceil((window+stride-2)/2)), max(stride-1, floor((window+stride-2)/2)))
          pad_total = window+stride-2
          when stride > window -> (window-1, stride-1)
* VALID:  padding=(window-1, max(stride-1, window-1))
          pad_total = window+stride-2 + max(window-stride, 0)
          when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
          pad_total = window+stride-2

Note:
    output_size = input_size*stride - (window+stride-2) + pad_total
                = input_size*stride                          <- "SAME", "CAUSAL"
                = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: in the equations above, `window` can be replaced with `dilate_window` when dilation > 1, where dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window().

The following illustrations demonstrate how transpose convolution operates, assuming all kernel values are set to 1 for simplicity in showcasing output values.
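The literal-padding formulas above can be sanity-checked with a short Python sketch (the helper names below are illustrative, not the PR's API; 1D case with dilation=1):

```python
import math


def transpose_conv_paddings(window: int, stride: int, padding: str) -> tuple[int, int]:
    """Literal paddings per the formulas above (1D, dilation=1)."""
    if padding == "SAME":
        return (
            min(window - 1, math.ceil((window + stride - 2) / 2)),
            max(stride - 1, math.floor((window + stride - 2) / 2)),
        )
    if padding == "VALID":
        return (window - 1, max(stride - 1, window - 1))
    if padding == "CAUSAL":
        return (window - 1, stride - 1)
    raise ValueError(f"Unsupported padding: {padding}")


def transpose_conv_output_size(input_size: int, window: int, stride: int, padding: str) -> int:
    """output_size = input_size*stride - (window+stride-2) + pad_total."""
    pad_total = sum(transpose_conv_paddings(window, stride, padding))
    return input_size * stride - (window + stride - 2) + pad_total
```

For example, with window=3 and stride=2 this yields paddings (2, 1), (2, 2), and (2, 1) for SAME, VALID, and CAUSAL respectively, matching the illustrations below.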
In the illustrations below, pipes delimit the (stride-dilated) input, values outside the pipes are padding, and `*` marks zeros inserted between input elements by the stride.

In the window=3 and stride=1 case, this function creates outputs as follows:

* "SAME" padding=(1, 1)

      paddings: 0|0 0 1 1|0
                0 0 0 -> 0
                  0 0 1 -> 1
                    0 1 1 -> 2
                      1 1 0 -> 2

* "VALID" padding=(2, 2)

      paddings: 0 0|0 0 1 1|0 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 1 -> 1
                      0 1 1 -> 2
                        1 1 0 -> 2
                          1 0 0 -> 1

* "CAUSAL" padding=(2, 0)

      paddings: 0 0|0 0 1 1|
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 1 -> 1
                      0 1 1 -> 2

In the window=3 and stride=2 case, this function creates outputs as follows:

* "SAME" padding=(2, 1)

      paddings: 0 0|0 * 0 * 1 * 1|0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 0 -> 0
                        0 0 1 -> 1
                          0 1 0 -> 1
                            1 0 1 -> 2
                              0 1 0 -> 1

* "VALID" padding=(2, 2)

      paddings: 0 0|0 * 0 * 1 * 1|0 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 0 -> 0
                        0 0 1 -> 1
                          0 1 0 -> 1
                            1 0 1 -> 2
                              0 1 0 -> 1
                                1 0 0 -> 1

* "CAUSAL" padding=(2, 1)

      paddings: 0 0|0 * 0 * 1 * 1|0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 0 -> 0
                        0 0 1 -> 1
                          0 1 0 -> 1
                            1 0 1 -> 2
                              0 1 0 -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:

* "SAME", "VALID" and "CAUSAL" padding=(2, 2)

      paddings: 0 0|0 * * 0 * * 1 * * 1|0 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 0 -> 0
                        0 0 0 -> 0
                          0 0 0 -> 0
                            0 0 1 -> 1
                              0 1 0 -> 1
                                1 0 0 -> 1
                                  0 0 1 -> 1
                                    0 1 0 -> 1
                                      1 0 0 -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:

* "SAME", "VALID" and "CAUSAL" padding=(2, 3)

      paddings: 0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 0 -> 0
                        0 0 0 -> 0
                          0 0 0 -> 0
                            0 0 0 -> 0
                              0 0 0 -> 0
                                0 0 1 -> 1
                                  0 1 0 -> 1
                                    1 0 0 -> 1
                                      0 0 0 -> 0
                                        0 0 1 -> 1
                                          0 1 0 -> 1
                                            1 0 0 -> 1
                                              0 0 0 -> 0

Here is how to compute output_size, given the above examples:

    1. |_|                           -(window-1)
    2.   |_____________________|     (input_size-1)*stride + 1
    3. |_|                     |___| + pad_total

So,

    output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                = input_size*stride - (window+stride-2) + pad_total
                = input_size*stride                          <- "SAME", "CAUSAL"
                = input_size*stride + max(window-stride, 0)  <- "VALID"

OTOH, when dilation > 1, dilate_window = (window - 1) * dilation + 1. For example, when window=3 and dilation=2, dilate_window=5. In the stride=2 case, this function creates outputs as follows:

* "SAME" padding=(3, 2)

      paddings: 0 0 0|0 * 0 * 1 * 1|0 0
                0 * 0 * 0 -> 0
                  0 * 0 * 0 -> 0
                    0 * 0 * 0 -> 0
                      0 * 0 * 1 -> 1
                        0 * 0 * 0 -> 0
                          0 * 1 * 1 -> 2
                            0 * 0 * 0 -> 0
                              1 * 1 * 0 -> 2

* "VALID" padding=(4, 4)

      paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
                0 * 0 * 0 -> 0
                  0 * 0 * 0 -> 0
                    0 * 0 * 0 -> 0
                      0 * 0 * 0 -> 0
                        0 * 0 * 1 -> 1
                          0 * 0 * 0 -> 0
                            0 * 1 * 1 -> 2
                              0 * 0 * 0 -> 0
                                1 * 1 * 0 -> 2
                                  0 * 0 * 0 -> 0
                                    1 * 0 * 0 -> 1

* "CAUSAL" padding=(4, 1)

      paddings: 0 0 0 0|0 * 0 * 1 * 1|0
                0 * 0 * 0 -> 0
                  0 * 0 * 0 -> 0
                    0 * 0 * 0 -> 0
                      0 * 0 * 0 -> 0
                        0 * 0 * 1 -> 1
                          0 * 0 * 0 -> 0
                            0 * 1 * 1 -> 2
                              0 * 0 * 0 -> 0
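The dilate-then-pad-then-slide procedure illustrated above can be emulated in a few lines of NumPy (an illustrative sketch, not the PR's implementation; with the all-ones kernels used in the illustrations, kernel flipping is a no-op, so plain correlation suffices):

```python
import numpy as np


def conv_transpose_1d(x: np.ndarray, kernel: np.ndarray, stride: int, pads: tuple[int, int]) -> np.ndarray:
    """Emulate 1D transposed conv: dilate the input by the stride, pad, then correlate."""
    # Insert (stride - 1) zeros between input elements.
    dilated = np.zeros((len(x) - 1) * stride + 1, dtype=x.dtype)
    dilated[::stride] = x
    # Apply (left, right) paddings, then slide the window with step 1.
    padded = np.pad(dilated, pads)
    window = len(kernel)
    return np.array([padded[i:i + window] @ kernel for i in range(len(padded) - window + 1)])
```

With x=[0, 0, 1, 1], an all-ones window of 3, stride=2, and the "SAME" paddings (2, 1), this reproduces the output column [0, 0, 0, 0, 1, 1, 2, 1] from the corresponding illustration above.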
1 parent b352259 commit bad0f0f

File tree

6 files changed

+1683
-206
lines changed


axlearn/common/conformer.py

Lines changed: 11 additions & 3 deletions
@@ -32,7 +32,7 @@
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
 from axlearn.common.layers import (
     BatchNorm,
-    DepthwiseConv1D,
+    Conv1D,
     Dropout,
     GroupNorm,
     LayerNorm,
@@ -72,7 +72,7 @@ class Config(BaseLayer.Config):
         linear1_norm: LayerNorm.Config = LayerNorm.default_config()
         linear1_activation: tuple[str, str] = ("linear", "nn.sigmoid")
         linear1: Linear.Config = Linear.default_config().set(bias=True)
-        conv: DepthwiseConv1D.Config = DepthwiseConv1D.default_config().set(
+        conv: Conv1D.Config = Conv1D.default_config().set(
             # See Table 2 and 7.
             window=32,
             bias=False,
@@ -96,7 +96,15 @@ def __init__(self, cfg: Config, *, parent: Module):
             cfg.linear1.set(input_dim=cfg.input_dim, output_dim=cfg.input_dim),
         )

-        self._add_child("conv", cfg.conv.set(input_dim=cfg.input_dim))
+        # Setup Depthwise Convolution (3 dims are same).
+        self._add_child(
+            "conv",
+            cfg.conv.set(
+                input_dim=cfg.input_dim,
+                output_dim=cfg.input_dim,
+                num_input_dim_groups=cfg.input_dim,
+            ),
+        )
         self._add_child("conv_norm", cfg.conv_norm.set(input_dim=cfg.input_dim))
         self._add_child(
             "linear2",
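For context on the diff above: Conv1D with num_input_dim_groups equal to the input dim is a grouped convolution with one group per channel, i.e. a depthwise convolution, which is why it can replace DepthwiseConv1D. A minimal NumPy sketch of that computation (illustrative only, not the library's code; stride 1, VALID padding):

```python
import numpy as np


def depthwise_conv1d(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Depthwise 1D convolution: groups == channels, one filter per channel.

    x: (length, channels); kernel: (window, channels).
    """
    window, channels = kernel.shape
    out_len = x.shape[0] - window + 1  # VALID padding, stride 1
    out = np.empty((out_len, channels))
    for c in range(channels):
        # Each output channel depends only on its own input channel.
        for i in range(out_len):
            out[i, c] = x[i:i + window, c] @ kernel[:, c]
    return out
```

A full grouped convolution generalizes this by convolving each group of input channels with its own set of filters; depthwise is the special case of one channel per group.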
