start removing groupnorms because of https://arxiv.org/abs/2312.02696
lucidrains committed May 3, 2024
1 parent dfc5d53 commit c166739
Showing 3 changed files with 51 additions and 41 deletions.
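Summary of the change: every `Block` previously normalized its input with `nn.GroupNorm(groups, dim)`; it now uses a channel-wise RMSNorm (`ChanRMSNorm`) that learns only a per-channel gain, and the `groups` / `resnet_groups` knobs are removed throughout. A minimal sketch comparing the old and new norm on a dummy feature map (the `ChanRMSNorm` definition is the one added in the diff below; the tensor sizes are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChanRMSNorm(nn.Module):
    # channel-wise RMSNorm as added in this commit (2D image variant)
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        # l2-normalize across channels, then restore magnitude with a learned per-channel gain
        return F.normalize(x, dim = 1) * self.scale * self.gamma

x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)

old_norm = nn.GroupNorm(8, 64)   # previous default: 8 groups, learned scale and shift
new_norm = ChanRMSNorm(64)       # replacement: no mean subtraction, gain only

assert old_norm(x).shape == new_norm(x).shape == x.shape
```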
45 changes: 25 additions & 20 deletions imagen_pytorch/imagen_pytorch.py
@@ -319,6 +319,15 @@ def predict_start_from_noise(self, x_t, t, noise):

 # norms and residuals

+class ChanRMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class LayerNorm(nn.Module):
     def __init__(self, feats, stable = False, dim = -1):
         super().__init__()
@@ -664,16 +673,15 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8,
         norm = True
     ):
         super().__init__()
-        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
+        self.norm = ChanRMSNorm(dim) if norm else Identity()
         self.activation = nn.SiLU()
         self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

     def forward(self, x, scale_shift = None):
-        x = self.groupnorm(x)
+        x = self.norm(x)

         if exists(scale_shift):
             scale, shift = scale_shift
@@ -690,7 +698,6 @@ def __init__(
         *,
         cond_dim = None,
         time_cond_dim = None,
-        groups = 8,
         linear_attn = False,
         use_gca = False,
         squeeze_excite = False,
@@ -717,8 +724,8 @@ def __init__(
             **attn_kwargs
         )

-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)

         self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

@@ -1133,7 +1140,6 @@ def __init__(
         cond_on_text = True,
         max_text_len = 256,
         init_dim = None,
-        resnet_groups = 8,
         init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1289,7 +1295,6 @@ def __init__(
         # resnet block klass

         num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
-        resnet_groups = cast_tuple(resnet_groups, num_layers)

         resnet_klass = partial(ResnetBlock, **attn_kwargs)

@@ -1300,7 +1305,7 @@ def __init__(
         use_linear_attn = cast_tuple(use_linear_attn, num_layers)
         use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)

-        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
+        assert all([layers == num_layers for layers in list(map(len, (layer_attns, layer_cross_attns)))])

         # downsample klass

@@ -1311,7 +1316,7 @@ def __init__(

         # initial resnet block (for memory efficient unet)

-        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
+        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) if memory_efficient else None

         # scale for resnet skip connections

@@ -1323,14 +1328,14 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)

-        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
+        layer_params = [num_resnet_blocks, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
         reversed_layer_params = list(map(reversed, layer_params))

         # downsampling layers

         skip_connect_dims = [] # keep track of skip connection dimensions

-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
             is_last = ind >= (num_resolutions - 1)

             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
@@ -1362,8 +1367,8 @@ def __init__(

             self.downs.append(nn.ModuleList([
                 pre_downsample,
-                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                 post_downsample
             ]))
@@ -1372,9 +1377,9 @@ def __init__(

         mid_dim = dims[-1]

-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = TransformerBlock(mid_dim, depth = layer_mid_attns_depth, **attn_kwargs) if attend_at_middle else None
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)

         # upsample klass

@@ -1384,7 +1389,7 @@ def __init__(

         upsample_fmap_dims = []

-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
             is_last = ind == (len(in_out) - 1)

             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
@@ -1401,8 +1406,8 @@ def __init__(
             upsample_fmap_dims.append(dim_out)

             self.ups.append(nn.ModuleList([
-                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
             ]))
@@ -1423,7 +1428,7 @@ def __init__(

         # final optional resnet block and convolution out

-        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
+        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, use_gca = True) if final_resnet_block else None

         final_conv_dim_in = dim if final_resnet_block else final_conv_dim
         final_conv_dim_in += (channels if lowres_cond else 0)
45 changes: 25 additions & 20 deletions imagen_pytorch/imagen_video.py
@@ -204,6 +204,15 @@ def forward(self, x):
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + eps).rsqrt() * self.g

+class ChanRMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class ChanLayerNorm(nn.Module):
     def __init__(self, dim, stable = False):
         super().__init__()
@@ -709,11 +718,10 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8,
         norm = True
     ):
         super().__init__()
-        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
+        self.norm = ChanRMSNorm(dim) if norm else Identity()
         self.activation = nn.SiLU()
         self.project = Conv3d(dim, dim_out, 3, padding = 1)

@@ -723,7 +731,7 @@ def forward(
         scale_shift = None,
         ignore_time = False
     ):
-        x = self.groupnorm(x)
+        x = self.norm(x)

         if exists(scale_shift):
             scale, shift = scale_shift
@@ -740,7 +748,6 @@ def __init__(
         *,
         cond_dim = None,
         time_cond_dim = None,
-        groups = 8,
         linear_attn = False,
         use_gca = False,
         squeeze_excite = False,
@@ -767,8 +774,8 @@ def __init__(
             **attn_kwargs
         )

-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)

         self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

@@ -1249,7 +1256,6 @@ def __init__(
         cond_on_text = True,
         max_text_len = 256,
         init_dim = None,
-        resnet_groups = 8,
         init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1412,15 +1418,14 @@ def __init__(
         # resnet block klass

         num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
-        resnet_groups = cast_tuple(resnet_groups, num_layers)

         resnet_klass = partial(ResnetBlock, **attn_kwargs)

         layer_attns = cast_tuple(layer_attns, num_layers)
         layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
         layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

-        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
+        assert all([layers == num_layers for layers in list(map(len, (layer_attns, layer_cross_attns)))])

         # temporal downsample config

@@ -1436,7 +1441,7 @@ def __init__(

         # initial resnet block (for memory efficient unet)

-        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
+        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) if memory_efficient else None

         self.init_temporal_peg = temporal_peg(init_dim)
         self.init_temporal_attn = temporal_attn(init_dim)
@@ -1451,14 +1456,14 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)

-        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
+        layer_params = [num_resnet_blocks, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
         reversed_layer_params = list(map(reversed, layer_params))

         # downsampling layers

         skip_connect_dims = [] # keep track of skip connection dimensions

-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
             is_last = ind >= (num_resolutions - 1)

             layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
@@ -1486,8 +1491,8 @@ def __init__(

             self.downs.append(nn.ModuleList([
                 pre_downsample,
-                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(current_dim),
                 temporal_attn(current_dim),
@@ -1499,11 +1504,11 @@ def __init__(

         mid_dim = dims[-1]

-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = Residual(Attention(mid_dim, **attn_kwargs)) if attend_at_middle else None
         self.mid_temporal_peg = temporal_peg(mid_dim)
         self.mid_temporal_attn = temporal_attn(mid_dim)
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)

         # upsample klass

@@ -1513,7 +1518,7 @@ def __init__(

         upsample_fmap_dims = []

-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
             is_last = ind == (len(in_out) - 1)
             layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
@@ -1524,8 +1529,8 @@ def __init__(
             upsample_fmap_dims.append(dim_out)

             self.ups.append(nn.ModuleList([
-                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(dim_out),
                 temporal_attn(dim_out),
@@ -1549,7 +1554,7 @@ def __init__(

         # final optional resnet block and convolution out

-        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
+        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, use_gca = True) if final_resnet_block else None

         final_conv_dim_in = dim if final_resnet_block else final_conv_dim
         final_conv_dim_in += (channels if lowres_cond else 0)
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.26.3'
+__version__ = '2.0.0'
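Note that dropping `groups` / `resnet_groups` from `Block`, `ResnetBlock`, and `Unet` is a breaking API change, which is presumably why the version jumps to 2.0.0. A hedged upgrade sketch (the other constructor arguments follow README-style usage and are illustrative):

```python
from imagen_pytorch import Unet

# Pre-2.0.0, the group norm granularity was configurable via `resnet_groups = 8`.
# From 2.0.0 on, the kwarg no longer exists, so drop it when constructing the Unet.
unet = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
    # resnet_groups = 8   <- remove this argument when upgrading
)
```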

8 comments on commit c166739

@danbochman

Does this work better even without the magnitude preserving layers? I got the impression they are supposed to work together
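(For context: the magnitude-preserving layers from that paper are, roughly, convolutions whose weights are forced to unit norm so activation magnitudes stay constant without any normalization layer. The sketch below is adapted from the paper's description, not from this repository or the official EDM2 code, and it omits details such as the forced re-normalization of the stored weights during training.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MPConv2d(nn.Module):
    # rough sketch of a magnitude-preserving conv in the spirit of arXiv:2312.02696
    def __init__(self, dim_in, dim_out, kernel_size = 3):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in, kernel_size, kernel_size))

    def forward(self, x, gain = 1.):
        w = self.weight
        # constrain every output filter to unit norm, so for roughly unit-variance
        # inputs the convolution neither grows nor shrinks activation magnitudes
        w = F.normalize(w.flatten(1), dim = 1).view_as(w) * gain
        return F.conv2d(x, w, padding = w.shape[-1] // 2)
```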

@lucidrains

@danbochman it isn't just Karras' paper

a Brain researcher told me years ago to be cautious around using groupnorms, which I ignored at the time. a recent issue in another repo tipped me over the edge

@danbochman

@lucidrains thanks for the reply and reference, I will give it a go and let you know how it works.
I also got rid of groupnorms in the past and replaced them with the adaptive group norm from k-diffusion, and it really helped.
But this is a much more parameter-friendly alternative.
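(For reference, "adaptive group norm" here means a parameter-free GroupNorm whose per-channel scale and shift are predicted from the conditioning embedding, as used in ADM/k-diffusion-style UNets. The sketch below is illustrative rather than k-diffusion's exact code; the class and argument names are made up for the example.)

```python
import torch
import torch.nn as nn

class AdaptiveGroupNorm(nn.Module):
    # illustrative adaptive group norm: GroupNorm without affine parameters,
    # followed by a scale and shift predicted from a conditioning embedding
    def __init__(self, groups, dim, cond_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, dim, affine = False)
        self.to_scale_shift = nn.Linear(cond_dim, dim * 2)
        nn.init.zeros_(self.to_scale_shift.weight)
        nn.init.zeros_(self.to_scale_shift.bias)

    def forward(self, x, cond):
        # cond: (batch, cond_dim) -> per-channel (scale, shift)
        scale, shift = self.to_scale_shift(cond).chunk(2, dim = -1)
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return self.norm(x) * (scale + 1) + shift

# usage on a dummy feature map and conditioning embedding
x = torch.randn(2, 64, 32, 32)
cond = torch.randn(2, 512)
out = AdaptiveGroupNorm(8, 64, 512)(x, cond)
```

The extra `nn.Linear(cond_dim, dim * 2)` per norm is the parameter cost being contrasted with the single per-channel gain of `ChanRMSNorm`.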

@lucidrains

@danbochman cool, let me know what you see on your end!

@danbochman

@lucidrains It is too early to say anything about final output quality, but early convergence is noticeably faster (roughly 4x), and results look quite good for simple examples.

Virtual Try-On task
[image: Ground Truth | Input Person | Input Garment | Model Output]

@lucidrains commented on May 22, 2024

nice! leave it to Karras to point out something everyone commonly uses is defective..

@danbochman

Following up on this to report that with long training runs, colors start to get consistently shifted.

[three example images showing the color shift]

Might be related to some other random training dynamic, but unfortunately for now I am reverting back to (adaptive) group norms.

@lucidrains

thanks for the data point
