
Consider Making Normalization Optional in MappingFeedForwardBlock #94

Open · mnslarcher opened this issue Feb 1, 2024 · 2 comments

@mnslarcher
Hi there!

I've noticed that in the forward method of MappingNetwork, you apply RMSNorm to the input:

class MappingNetwork(nn.Module):
    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.in_norm = RMSNorm(d_model)
        self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)])
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        x = self.in_norm(x)
        for block in self.blocks:
            x = block(x)
        x = self.out_norm(x)
        return x

However, MappingFeedForwardBlock also performs normalization, which means the first block normalizes input that has already been normalized. Here's the current implementation for reference:

class MappingFeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False))
        self.dropout = nn.Dropout(dropout)
        self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False)))

    def forward(self, x):
        skip = x
        x = self.norm(x)
        x = self.up_proj(x)
        x = self.dropout(x)
        x = self.down_proj(x)
        return x + skip

Wouldn't it make sense to introduce an option to toggle normalization in MappingFeedForwardBlock and turn it off for the first block?

To be honest, even in the current setup the RMSNorm in the first block still plays a role, since skip and x can end up at different scales.

Just a thought while reviewing the code – feel free to ignore if it's not relevant!

@stefan-baumann

Good catch! Yeah, we should consider changing this, although at this point (given that it also seems to work well with the double norm) we have to weigh whether it's worth introducing potentially breaking changes to the public implementation.

@mnslarcher (Author)

Sure, it's not a big deal. Maybe I'd consider adding something like use_norm=True (or norm=True) to MappingFeedForwardBlock: the default would keep things exactly as they are now, but later on, if you want to turn off normalization when reusing the block, you'd have the option. Anyway, it's pretty much a ~0-impact thing, probably not worth making the code more complex.
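
For illustration, here's a minimal sketch of what that could look like, reusing the names from the snippets above (use_norm is just a hypothetical parameter name, and the change to MappingNetwork is only one way to wire it up):

class MappingFeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0, use_norm=True):
        super().__init__()
        # nn.Identity() keeps forward() unchanged when the norm is disabled
        self.norm = RMSNorm(d_model) if use_norm else nn.Identity()
        self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False))
        self.dropout = nn.Dropout(dropout)
        self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False)))

    def forward(self, x):
        skip = x
        x = self.norm(x)
        x = self.up_proj(x)
        x = self.dropout(x)
        x = self.down_proj(x)
        return x + skip

class MappingNetwork(nn.Module):
    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.in_norm = RMSNorm(d_model)
        # The first block already receives input normalized by in_norm,
        # so its own norm could be skipped; later blocks keep theirs.
        self.blocks = nn.ModuleList([
            MappingFeedForwardBlock(d_model, d_ff, dropout=dropout, use_norm=(i > 0))
            for i in range(n_layers)
        ])
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        x = self.in_norm(x)
        for block in self.blocks:
            x = block(x)
        x = self.out_norm(x)
        return x

The default use_norm=True keeps the current behavior, so existing call sites and checkpoints wouldn't be affected.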
