diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index 6a134cf..bd2918c 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -319,6 +319,15 @@ def predict_start_from_noise(self, x_t, t, noise):
 
 # norms and residuals
 
+class ChanRMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class LayerNorm(nn.Module):
     def __init__(self, feats, stable = False, dim = -1):
         super().__init__()
@@ -664,16 +673,15 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8,
         norm = True
     ):
         super().__init__()
-        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
+        self.norm = ChanRMSNorm(dim) if norm else Identity()
         self.activation = nn.SiLU()
         self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
 
     def forward(self, x, scale_shift = None):
-        x = self.groupnorm(x)
+        x = self.norm(x)
 
         if exists(scale_shift):
             scale, shift = scale_shift
@@ -690,7 +698,6 @@ def __init__(
         *,
         cond_dim = None,
         time_cond_dim = None,
-        groups = 8,
         linear_attn = False,
         use_gca = False,
         squeeze_excite = False,
@@ -717,8 +724,8 @@ def __init__(
                 **attn_kwargs
             )
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)
 
         self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)
 
@@ -1133,7 +1140,6 @@ def __init__(
         cond_on_text = True,
         max_text_len = 256,
         init_dim = None,
-        resnet_groups = 8,
         init_conv_kernel_size = 7,          # kernel size of initial conv, if not using cross embed
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1289,7 +1295,6 @@ def __init__(
         # resnet block klass
 
         num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
-        resnet_groups = cast_tuple(resnet_groups, num_layers)
 
         resnet_klass = partial(ResnetBlock, **attn_kwargs)
 
@@ -1300,7 +1305,7 @@ def __init__(
         use_linear_attn = cast_tuple(use_linear_attn, num_layers)
         use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)
 
-        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
+        assert all([layers == num_layers for layers in list(map(len, (layer_attns, layer_cross_attns)))])
 
         # downsample klass
 
@@ -1311,7 +1316,7 @@ def __init__(
 
         # initial resnet block (for memory efficient unet)
 
-        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
+        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) if memory_efficient else None
 
         # scale for resnet skip connections
 
@@ -1323,14 +1328,14 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
+        layer_params = [num_resnet_blocks, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
         reversed_layer_params = list(map(reversed, layer_params))
 
         # downsampling layers
 
         skip_connect_dims = [] # keep track of skip connection dimensions
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
             is_last = ind >= (num_resolutions - 1)
 
             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
@@ -1362,8 +1367,8 @@ def __init__(
 
             self.downs.append(nn.ModuleList([
                 pre_downsample,
-                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                 post_downsample
             ]))
@@ -1372,9 +1377,9 @@ def __init__(
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = TransformerBlock(mid_dim, depth = layer_mid_attns_depth, **attn_kwargs) if attend_at_middle else None
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
 
         # upsample klass
 
@@ -1384,7 +1389,7 @@ def __init__(
 
         upsample_fmap_dims = []
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
             is_last = ind == (len(in_out) - 1)
 
             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
@@ -1401,8 +1406,8 @@ def __init__(
             upsample_fmap_dims.append(dim_out)
 
             self.ups.append(nn.ModuleList([
-                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
             ]))
@@ -1423,7 +1428,7 @@ def __init__(
 
         # final optional resnet block and convolution out
 
-        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
+        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, use_gca = True) if final_resnet_block else None
 
         final_conv_dim_in = dim if final_resnet_block else final_conv_dim
         final_conv_dim_in += (channels if lowres_cond else 0)
diff --git a/imagen_pytorch/imagen_video.py b/imagen_pytorch/imagen_video.py
index e9f6d15..b8ad573 100644
--- a/imagen_pytorch/imagen_video.py
+++ b/imagen_pytorch/imagen_video.py
@@ -204,6 +204,15 @@ def forward(self, x):
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + eps).rsqrt() * self.g
 
+class ChanRMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class ChanLayerNorm(nn.Module):
     def __init__(self, dim, stable = False):
         super().__init__()
@@ -709,11 +718,10 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8,
         norm = True
     ):
         super().__init__()
-        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
+        self.norm = ChanRMSNorm(dim) if norm else Identity()
         self.activation = nn.SiLU()
         self.project = Conv3d(dim, dim_out, 3, padding = 1)
 
@@ -723,7 +731,7 @@ def forward(
         scale_shift = None,
         ignore_time = False
     ):
-        x = self.groupnorm(x)
+        x = self.norm(x)
 
         if exists(scale_shift):
             scale, shift = scale_shift
@@ -740,7 +748,6 @@ def __init__(
         *,
         cond_dim = None,
         time_cond_dim = None,
-        groups = 8,
         linear_attn = False,
         use_gca = False,
         squeeze_excite = False,
@@ -767,8 +774,8 @@ def __init__(
             **attn_kwargs
         )
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)
 
         self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)
 
@@ -1249,7 +1256,6 @@ def __init__(
         cond_on_text = True,
         max_text_len = 256,
         init_dim = None,
-        resnet_groups = 8,
         init_conv_kernel_size = 7,          # kernel size of initial conv, if not using cross embed
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1412,7 +1418,6 @@ def __init__(
         # resnet block klass
 
         num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
-        resnet_groups = cast_tuple(resnet_groups, num_layers)
 
         resnet_klass = partial(ResnetBlock, **attn_kwargs)
 
@@ -1420,7 +1425,7 @@ def __init__(
         layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
         layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
 
-        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
+        assert all([layers == num_layers for layers in list(map(len, (layer_attns, layer_cross_attns)))])
 
         # temporal downsample config
 
@@ -1436,7 +1441,7 @@ def __init__(
 
         # initial resnet block (for memory efficient unet)
 
-        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
+        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) if memory_efficient else None
 
         self.init_temporal_peg = temporal_peg(init_dim)
         self.init_temporal_attn = temporal_attn(init_dim)
@@ -1451,14 +1456,14 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
+        layer_params = [num_resnet_blocks, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
         reversed_layer_params = list(map(reversed, layer_params))
 
         # downsampling layers
 
         skip_connect_dims = [] # keep track of skip connection dimensions
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
             is_last = ind >= (num_resolutions - 1)
 
             layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
@@ -1486,8 +1491,8 @@ def __init__(
 
             self.downs.append(nn.ModuleList([
                 pre_downsample,
-                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(current_dim),
                 temporal_attn(current_dim),
@@ -1499,11 +1504,11 @@ def __init__(
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = Residual(Attention(mid_dim, **attn_kwargs)) if attend_at_middle else None
         self.mid_temporal_peg = temporal_peg(mid_dim)
         self.mid_temporal_attn = temporal_attn(mid_dim)
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
 
         # upsample klass
 
@@ -1513,7 +1518,7 @@ def __init__(
 
         upsample_fmap_dims = []
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
             is_last = ind == (len(in_out) - 1)
             layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
@@ -1524,8 +1529,8 @@ def __init__(
             upsample_fmap_dims.append(dim_out)
 
             self.ups.append(nn.ModuleList([
-                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
+                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim),
+                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                 transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(dim_out),
                 temporal_attn(dim_out),
@@ -1549,7 +1554,7 @@ def __init__(
 
         # final optional resnet block and convolution out
 
-        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
+        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, use_gca = True) if final_resnet_block else None
 
         final_conv_dim_in = dim if final_resnet_block else final_conv_dim
         final_conv_dim_in += (channels if lowres_cond else 0)
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index 22db8a6..afced14 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.26.3'
+__version__ = '2.0.0'
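
Note for reviewers (not part of the patch): the new ChanRMSNorm l2-normalizes each spatial position's channel vector and rescales by sqrt(dim), so with gamma at its ones-init every position comes out with unit RMS across channels, replacing GroupNorm's per-group mean/variance statistics with a single parameter-light operation. A minimal standalone sketch, duplicating the 2d class the patch adds to imagen_pytorch.py:

import torch
import torch.nn.functional as F
from torch import nn

class ChanRMSNorm(nn.Module):
    # same definition the patch adds (2d case; the video variant
    # only differs in gamma's shape, (dim, 1, 1, 1) for an extra frame axis)
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        # unit-l2-normalize each channel vector, then rescale by sqrt(dim),
        # so per-position RMS over channels is 1 before gamma is applied
        return F.normalize(x, dim = 1) * self.scale * self.gamma

norm = ChanRMSNorm(64)
out = norm(torch.randn(2, 64, 32, 32))

print(out.pow(2).mean(dim = 1).sqrt())  # ~1.0 at every (batch, h, w) position

Unlike GroupNorm, there is no mean subtraction and no bias, only the gamma gain, which is also why the Block/ResnetBlock/Unet constructors can drop the groups / resnet_groups knobs entirely.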
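The jump to 2.0.0 signals a breaking change: 1.x checkpoints serialize GroupNorm parameters under '...groupnorm.weight' / '...groupnorm.bias', while the new blocks register a single '...norm.gamma', and the removed groups / resnet_groups kwargs will now raise a TypeError if passed. A hedged sketch of how an old checkpoint surfaces the mismatch (the Unet config and checkpoint path below are illustrative, not from the patch):

import torch
from imagen_pytorch import Unet

unet = Unet(dim = 32, cond_dim = 64, dim_mults = (1, 2, 4))  # illustrative config

# hypothetical path to a checkpoint saved under 1.x
old_state = torch.load('./unet-1.26.3.pt', map_location = 'cpu')

# a strict load would raise; strict = False reports the key mismatch instead
missing, unexpected = unet.load_state_dict(old_state, strict = False)
print(missing)     # new '...norm.gamma' parameters with no saved counterpart
print(unexpected)  # orphaned '...groupnorm.weight' / '...groupnorm.bias' entries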