
Clone HF's GPT2 to create GPTMeg with a few tiny changes #138

@stas00


As can be seen from #121, Meg and HF GPT2 produce diverging outputs under fp16, even though they use the same weights.
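Why identical weights are not enough: in fp16 every intermediate result is rounded to 16 bits, so two mathematically equivalent groupings of the same operations can give different answers. A minimal illustration (a toy example, not data from #121), assuming only PyTorch on CPU:

```python
import torch

# Three fp16 values; 0.0004 is below half the fp16 spacing at 1.0 (2**-10 ≈ 0.00098)
a = torch.tensor(1.0, dtype=torch.float16)
b = torch.tensor(0.0004, dtype=torch.float16)
c = torch.tensor(0.0004, dtype=torch.float16)

# Same math, different grouping: each partial sum is rounded back to fp16
left = (a + b) + c   # 1.0 + 0.0004 rounds down to 1.0, twice
right = a + (b + c)  # 0.0008 is large enough to round 1.0 up to the next fp16 value

print(left.item(), right.item())
```

Differences like this inside the layer-norm reduction and the attention matmul compound across layers, which is why matching Meg requires matching its exact ops and precisions rather than just its weights.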

So the proposed solution for letting users run BigScience-pretrained models is to create a new architecture, GPTMeg, which would be an identical clone of HF's GPT2 with a few changes.

Here are the 3 changes, written as overrides applied on top of torch and transformers:

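import torch

def apply_overrides():

    # 1. Layer norm needs to be computed in fp32 and then cast back to fp16 to match Meg.
    #    torch.nn.functional.layer_norm calls torch.layer_norm at runtime, so patching it also covers nn.LayerNorm.
    torch_layer_norm_orig = torch.layer_norm
    def torch_layer_norm_force_fp32(input, normalized_shape, weight=None, bias=None, eps=1e-5, cudnn_enable=True):
        weight = weight.float() if weight is not None else None
        bias = bias.float() if bias is not None else None
        out = torch_layer_norm_orig(input.float(), normalized_shape, weight, bias, eps, cudnn_enable)
        # Cast back to the input's dtype (fp16 under mixed precision, a no-op for fp32 inputs)
        return out.to(input.dtype)
    torch.layer_norm = torch_layer_norm_force_fp32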


    # 2. MLP uses a slightly different activation function with a custom bwd
    import transformers.activations
    @torch.jit.script
    def gelu_megatron_fwd(x):
        return  x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def gelu_megatron_bwd(g, x):
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff*g

    class GeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return gelu_megatron_fwd(input)

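        @staticmethod
        def backward(ctx, grad_output):
            # saved_tensors is a tuple; unpack the single saved input
            input, = ctx.saved_tensors
            # forward() takes one tensor, so backward() must return exactly one gradient
            return gelu_megatron_bwd(grad_output, input)

    # GPT2MLP looks up ACT2FN when it is constructed, so this must run before the model is instantiated.
    # Assumes a transformers version where ACT2FN maps names to plain functions.
    transformers.activations.gelu_fast = GeLUFunction.apply
    transformers.activations.ACT2FN["gelu_fast"] = transformers.activations.gelu_fast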


    # 3. Meg computes attention scores with torch.baddbmm(), which produces slightly different results
    #    than torch.matmul(), so override GPT2Attention._attn to use torch.baddbmm as well.
    import transformers.models.gpt2.modeling_gpt2
    from torch import nn
    def new_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # query: [b, np, sq, hn], key: [b, np, sk, hn] -> scores: [b, np, sq, sk]
        output_size = (query.size(0), key.size(1), query.size(2), key.size(2))
        # With beta=0.0 baddbmm ignores the (uninitialized) contents of this buffer
        matmul_result = torch.empty(output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query.dtype, device=query.device)

        # Fold the 1/sqrt(head_dim) scaling into alpha, honoring scale_attn_weights as HF does
        factor = float(value.size(-1)) ** 0.5 if self.scale_attn_weights else 1.0
        matmul_result = torch.baddbmm(
            matmul_result,
            query.reshape(-1, query.shape[2], query.shape[3]),  # [b * np, sq, hn]
            key.reshape(-1, key.shape[2], key.shape[3]).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=1.0 / factor,
        )
        attn_weights = matmul_result.view(*output_size)

        # Original HF computation, replaced by the baddbmm call above:
        # attn_weights = torch.matmul(query, key.transpose(-1, -2))
        #
        # if self.scale_attn_weights:
        #     attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    transformers.models.gpt2.modeling_gpt2.GPT2Attention._attn = new_attn
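A quick fp32 sanity check for override 3: the baddbmm path should reproduce HF's matmul-then-scale scores up to rounding, including when the key length differs from the query length (as with cached past key/values during generation). This is a standalone sketch; the helper name `baddbmm_scores` and the tensor sizes are illustrative, not part of transformers or Meg.

```python
import torch

def baddbmm_scores(query, key):
    """Attention scores via baddbmm, mirroring override 3 (hypothetical helper)."""
    b, nh, sq, hn = query.shape
    sk = key.size(2)
    # Preallocated output; with beta=0.0 its uninitialized contents are ignored
    out = torch.empty(b * nh, sq, sk, dtype=query.dtype, device=query.device)
    out = torch.baddbmm(
        out,
        query.reshape(-1, sq, hn),                # [b * nh, sq, hn]
        key.reshape(-1, sk, hn).transpose(1, 2),  # [b * nh, hn, sk]
        beta=0.0,
        alpha=1.0 / hn ** 0.5,                    # fold 1/sqrt(head_dim) into alpha
    )
    return out.view(b, nh, sq, sk)

torch.manual_seed(0)
b, nh, sq, sk, hn = 2, 4, 3, 5, 8  # sk != sq, e.g. when past key/values are cached
query = torch.randn(b, nh, sq, hn)
key = torch.randn(b, nh, sk, hn)

# HF's path: matmul, then divide by sqrt(head_dim)
ref = torch.matmul(query, key.transpose(-1, -2)) / hn ** 0.5
scores = baddbmm_scores(query, key)
print(scores.shape, (scores - ref).abs().max().item())
```

In fp32 the two paths agree to within rounding; the fp16 discrepancy the override is meant to reproduce only appears with half-precision kernels, so this check only confirms that shapes and scaling line up.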

Here is how we are going to tackle the activation function: huggingface/transformers#13997

So a PR will need to be opened against https://github.com/huggingface/transformers/
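Whichever form the activation function takes in transformers (huggingface/transformers#13997), the custom backward in override 2 has to be the exact derivative of its forward. A minimal check, assuming float64 inputs on CPU, compares the hand-written backward against finite differences with `torch.autograd.gradcheck`. The `@torch.jit.script` decorators from the override are dropped here so the snippet runs without source inspection; they affect speed, not the math.

```python
import torch

def gelu_megatron_fwd(x):
    # tanh approximation of GeLU; 0.79788456 ≈ sqrt(2/pi)
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def gelu_megatron_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_megatron_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors  # unpack the single saved tensor
        return gelu_megatron_bwd(grad_output, input)  # one input -> one gradient

torch.manual_seed(0)
x = torch.randn(16, dtype=torch.float64, requires_grad=True)
# Raises on mismatch; returns True when analytic and numeric gradients agree
ok = torch.autograd.gradcheck(GeLUFunction.apply, (x,))
print(ok)
```

Returning a single gradient matters: `autograd.Function.backward` must return one value per `forward` input, so returning `tmp, tmp` for a one-input `forward` raises an error during the backward pass.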
