As can be seen from #121, we have a divergence between Megatron (Meg) and HF GPT2 while using the same weights under fp16.
So the proposed solution, to enable users to use BigScience-pretrained models, is to create a new architecture that would be an identical clone of HF's GPT2, but with a few changes.
Here are the 3 changes:
import torch


def apply_overrides():

    # 1. layer norm needs to be done in fp32 and then cast back to fp16 to match meg.
    torch_layer_norm_orig = torch.layer_norm

    def torch_layer_norm_force_fp32(input, normalized_shape, weight, bias, eps, cudnn):
        return torch_layer_norm_orig(input.float(), normalized_shape, weight.float(), bias.float(), eps, torch.backends.cudnn.enabled).half()

    torch.layer_norm = torch_layer_norm_force_fp32
    # 2. MLP uses a slightly different activation function with a custom bwd
    import transformers.activations

    @torch.jit.script
    def gelu_megatron_fwd(x):
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def gelu_megatron_bwd(g, x):
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff * g

    class GeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return gelu_megatron_fwd(input)

        @staticmethod
        def backward(ctx, grad_output):
            # forward takes a single tensor, so unpack it and return a single gradient
            input, = ctx.saved_tensors
            return gelu_megatron_bwd(grad_output, input)

    transformers.activations.gelu_fast = GeLUFunction.apply
    transformers.activations.ACT2FN["gelu_fast"] = transformers.activations.gelu_fast
    # 3. torch.baddbmm() (meg) produces slightly different results than torch.matmul, so override to use `torch.baddbmm`
    import transformers.models.gpt2.modeling_gpt2
    from torch import nn

    def new_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # [b, np, sq, sk]
        output_size = (query.size(0), key.size(1), query.size(2), key.size(2))
        matmul_result = torch.empty(output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query.dtype, device=query.device)
        factor = float(value.size(-1)) ** 0.5
        matmul_result = torch.baddbmm(
            matmul_result,
            query.reshape(-1, query.shape[2], query.shape[3]),            # [b * np, sq, hn]
            key.reshape(-1, key.shape[2], key.shape[3]).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=1.0 / factor,
        )
        attn_weights = matmul_result.view(*output_size)

        # attn_weights = torch.matmul(query, key.transpose(-1, -2))
        #
        # if self.scale_attn_weights:
        #     attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    transformers.models.gpt2.modeling_gpt2.GPT2Attention._attn = new_attn
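For context, here is a minimal sketch of how these overrides might be exercised to compare against Megatron's fp16 outputs. The checkpoint name, prompt, and the assumption that the converted config selects activation_function="gelu_fast" are illustrative, not taken from the issue:

    import torch
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast

    apply_overrides()  # patch layer norm, activation, and attention before any forward pass

    # "gpt2" is a placeholder; in practice this would be the converted BigScience checkpoint,
    # whose config is assumed here to use activation_function="gelu_fast"
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2").half().eval().cuda()

    inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
    with torch.no_grad():
        hf_logits = model(**inputs).logits
    # hf_logits can then be diffed against the Megatron fp16 logits for the same prompt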
Here is how we are going to tackle the activation function: huggingface/transformers#13997
So a PR will need to be filed with https://github.com/huggingface/transformers/.
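As a rough illustration of that upstream direction (the names below are hypothetical, not the contents of the linked PR), the Megatron-matching tanh GeLU would live in transformers.activations and be registered in ACT2FN so a model config can select it by name:

    import torch
    from transformers.activations import ACT2FN

    def gelu_megatron(x):
        # same tanh approximation as gelu_megatron_fwd above
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

    # "gelu_megatron" is an illustrative registry key; the PR settles the canonical name
    ACT2FN["gelu_megatron"] = gelu_megatron

    # a GPT2 config could then opt in with: config.activation_function = "gelu_megatron"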