Add GLM-4 and Later GLM Model (Draft) #31977

Closed
wants to merge 86 commits into from
Changes from 1 commit
Commits (86)
9cf74d7
add GLM-4
zRzRzRzRzRzRzR Jul 11, 2024
bef7fd9
GLM-4 FastTokenizer
zRzRzRzRzRzRzR Jul 11, 2024
c986fac
tokenizer fix
zRzRzRzRzRzRzR Jul 11, 2024
2da5d32
rename
zRzRzRzRzRzRzR Jul 11, 2024
675e7a1
pad token
zRzRzRzRzRzRzR Jul 11, 2024
304e4ef
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 11, 2024
0b241f2
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 12, 2024
fa44041
Fix past_key_values
duzx16 Jul 14, 2024
24dec6b
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 14, 2024
5d2bf5e
Merge branch 'glm-4' of github.com:zRzRzRzRzRzRzR/transformers into g…
duzx16 Jul 14, 2024
63d49c9
Fix flash attention
duzx16 Jul 14, 2024
0a5adf3
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 15, 2024
51cbf5d
add update
zRzRzRzRzRzRzR Jul 15, 2024
86b5004
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 15, 2024
9a553e5
test with glm
zRzRzRzRzRzRzR Jul 15, 2024
4d45b21
fix test
zRzRzRzRzRzRzR Jul 15, 2024
85cfe41
add discription
zRzRzRzRzRzRzR Jul 15, 2024
860c7ee
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 15, 2024
c83ec2d
update glm
zRzRzRzRzRzRzR Jul 16, 2024
2608010
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 16, 2024
1719000
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 18, 2024
3f0452e
rewrite tokenizer
zRzRzRzRzRzRzR Jul 18, 2024
33d2ca3
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 19, 2024
084988e
fix some test
zRzRzRzRzRzRzR Jul 19, 2024
0cb1531
fix testing
zRzRzRzRzRzRzR Jul 19, 2024
e49718f
Fix RMSNorm initialization
duzx16 Jul 20, 2024
a362206
Fix position ids when passing input_embeds
duzx16 Jul 20, 2024
08b43d9
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 20, 2024
3c5322d
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 23, 2024
dd06993
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 24, 2024
8cc0381
Fix dtype error
duzx16 Jul 24, 2024
a35997e
Merge branch 'glm-4' of github.com:zRzRzRzRzRzRzR/transformers into g…
duzx16 Jul 24, 2024
621d32f
Fix output_layer for classification models
duzx16 Jul 24, 2024
48d1704
fix gradient
zRzRzRzRzRzRzR Jul 24, 2024
5881ed5
remove some skip test
zRzRzRzRzRzRzR Jul 24, 2024
c920ad9
fix small test
zRzRzRzRzRzRzR Jul 24, 2024
21781b3
Fix prepare_inputs_for_generation
duzx16 Jul 24, 2024
9599200
Merge branch 'glm-4' of github.com:zRzRzRzRzRzRzR/transformers into g…
duzx16 Jul 24, 2024
a9b1d0d
fix
zRzRzRzRzRzRzR Jul 25, 2024
0631615
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
9f33751
add converter
zRzRzRzRzRzRzR Jul 25, 2024
2663a13
fix PEP 8
zRzRzRzRzRzRzR Jul 25, 2024
aad19db
remove test
zRzRzRzRzRzRzR Jul 25, 2024
1e9183c
index
zRzRzRzRzRzRzR Jul 25, 2024
e8b90a1
fix doctested
zRzRzRzRzRzRzR Jul 25, 2024
65e1996
remove init
zRzRzRzRzRzRzR Jul 25, 2024
266ce77
fix copied error
zRzRzRzRzRzRzR Jul 25, 2024
cd9c304
fix mlp differ
zRzRzRzRzRzRzR Jul 25, 2024
ba30dad
fix copied eerror
zRzRzRzRzRzRzR Jul 25, 2024
afb1423
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
48aaba1
test_hidden_states_output = False
zRzRzRzRzRzRzR Jul 25, 2024
33d976f
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 25, 2024
0675202
fix
zRzRzRzRzRzRzR Jul 25, 2024
19b0939
Update modeling_glm.py
zRzRzRzRzRzRzR Jul 25, 2024
b2b6c0f
Update __init__.py
zRzRzRzRzRzRzR Jul 25, 2024
6760791
fix glm type error
zRzRzRzRzRzRzR Jul 25, 2024
515d9d9
fix
zRzRzRzRzRzRzR Jul 25, 2024
9951c92
ruff problem
zRzRzRzRzRzRzR Jul 25, 2024
547ac95
Update convert_slow_tokenizer.py
zRzRzRzRzRzRzR Jul 25, 2024
9ba6cf7
Add explanations in English
zRzRzRzRzRzRzR Jul 25, 2024
9fb6405
reformate
zRzRzRzRzRzRzR Jul 25, 2024
e37bb49
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
25aec29
Update configuration_glm.py
zRzRzRzRzRzRzR Jul 25, 2024
58d344a
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 25, 2024
073b811
fix
zRzRzRzRzRzRzR Jul 25, 2024
c0e6ae9
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
6ac085f
fix glm dummy
zRzRzRzRzRzRzR Jul 25, 2024
f140603
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 25, 2024
65f471d
add doc
zRzRzRzRzRzRzR Jul 26, 2024
7ad819f
fix init
zRzRzRzRzRzRzR Jul 26, 2024
f86af8e
Update __init__.py
zRzRzRzRzRzRzR Jul 26, 2024
c179377
Update dummy_vision_objects.py
zRzRzRzRzRzRzR Jul 26, 2024
41338d7
add_start_docstrings
zRzRzRzRzRzRzR Jul 26, 2024
dba6d1e
fix GLM_START_DOCSTRING
zRzRzRzRzRzRzR Jul 26, 2024
82b0c7f
1
zRzRzRzRzRzRzR Jul 26, 2024
a6b6f4e
Update perf_infer_gpu_one.md
zRzRzRzRzRzRzR Jul 26, 2024
d1a5ee1
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 26, 2024
c99610e
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 27, 2024
b283adc
flash attn
zRzRzRzRzRzRzR Jul 27, 2024
4cc618e
stiil need fix rotary_emb
zRzRzRzRzRzRzR Jul 27, 2024
b476dd0
fix GLMSelfAttension
zRzRzRzRzRzRzR Jul 27, 2024
aab2386
remove _get_unpad_data
zRzRzRzRzRzRzR Jul 27, 2024
550a692
fix GLMSelfAttention
zRzRzRzRzRzRzR Jul 27, 2024
6492ac3
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 30, 2024
c3d4636
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Aug 9, 2024
70b7ff4
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Aug 21, 2024
Fix RMSNorm initialization
Fix attention mask for right padding
duzx16 committed Jul 20, 2024
commit e49718f6fce094e34bdaa512fc5673e54855fba9
133 changes: 92 additions & 41 deletions src/transformers/models/glm/modeling_glm.py
@@ -24,7 +24,7 @@
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss

from ...cache_utils import Cache, DynamicCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation.utils import ModelOutput
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -40,6 +40,8 @@
is_flash_attn_greater_or_equal_2_10,
logging,
)

from ...modeling_attn_mask_utils import AttentionMaskConverter
from .configuration_glm import GLMConfig

if is_flash_attn_2_available():
@@ -80,7 +82,7 @@ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs
GLMRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
self.eps = eps

def forward(self, hidden_states: torch.Tensor):
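The change above initializes the RMSNorm weight with `torch.ones` instead of `torch.empty`, so a freshly constructed layer scales by 1 rather than by whatever happens to be in uninitialized memory (which matters for `_init_weights` and training from scratch). A minimal sketch of the pattern, not the PR's exact module:

```python
import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm: the weight starts at 1.0, so an untrained layer is a pure RMS rescale."""

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        # torch.ones (rather than torch.empty) avoids feeding uninitialized memory
        # through the network before a checkpoint is loaded.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(input_dtype)
```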
@@ -427,14 +429,9 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
attention_scores = attention_scores.float()
if self.coeff is not None:
attention_scores = attention_scores * self.coeff
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
attention_mask = torch.ones(
output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool
)
attention_mask.tril_()
attention_mask = ~attention_mask
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
attention_scores = attention_scores + causal_mask
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.type_as(value_layer)
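
The rewritten block above assumes the 4D additive mask convention used elsewhere in the library: 0 where attention is allowed, a large negative value where it is not, sliced to the key length and simply added to the scores before the softmax. A standalone sketch of that pattern with made-up shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes, purely for illustration.
batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 8, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# 4D additive mask: 0.0 where a query may attend, dtype-min where it may not.
# The (kv_len - q_len) offset accounts for cached past keys sitting in front of the queries.
neg_inf = torch.finfo(query.dtype).min
allowed = torch.arange(kv_len) <= torch.arange(q_len)[:, None] + (kv_len - q_len)
causal_mask = torch.zeros(q_len, kv_len).masked_fill(~allowed, neg_inf)[None, None, :, :]

scores = (query @ key.transpose(-1, -2)) / head_dim**0.5
scores = scores + causal_mask[:, :, :, : key.shape[-2]]  # slice to key length, then add
probs = F.softmax(scores, dim=-1)
context = probs @ value  # (batch, heads, q_len, head_dim)
```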

@@ -598,8 +595,6 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
is_causal=True,
dropout_p=self.config.attention_dropout if self.training else 0.0)
else:
if attention_mask is not None:
attention_mask = ~attention_mask
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
attention_mask,
dropout_p=self.config.attention_dropout if self.training else 0.0)
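
Because the mask is now additive, the SDPA branch above no longer needs the `~attention_mask` inversion: `torch.nn.functional.scaled_dot_product_attention` accepts a float mask that is added to the attention scores. A small illustrative sketch of the two call styles (the shapes are made up):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 6, 16)
k = torch.randn(1, 4, 6, 16)
v = torch.randn(1, 4, 6, 16)

# Prefill with no padding: let SDPA build the causal mask itself.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# General case: pass a pre-built additive float mask (0.0 = attend, dtype-min = blocked).
neg = torch.finfo(q.dtype).min
additive_mask = torch.triu(torch.full((6, 6), neg), diagonal=1)[None, None, :, :]
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)

# The two results match up to numerical precision.
assert torch.allclose(out_causal, out_masked, atol=1e-5)
```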
@@ -659,36 +654,85 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def get_masks(self, input_ids, past_key_values, padding_mask=None):
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if padding_mask is not None and not padding_mask.all():
return padding_mask
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

batch_size, seq_length = input_ids.shape
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
full_attention_mask.tril_()

past_length = 0
if past_key_values:
past_length = past_key_values.get_seq_length()
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

if past_length:
full_attention_mask = torch.cat(
(torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

if padding_mask is not None:
padding_mask = padding_mask.bool() # Ensure padding_mask is a boolean tensor
expanded_padding_mask = padding_mask.unsqueeze(1).expand(-1, seq_length, -1)
full_attention_mask = full_attention_mask * expanded_padding_mask

if not past_length and padding_mask is not None:
full_attention_mask = full_attention_mask * (~padding_mask.unsqueeze(-1))

full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask

def get_position_ids(self, input_ids, device):
batch_size, seq_length = input_ids.shape
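
For reference, the new `_update_causal_mask` follows the pattern used by other decoder-only models in the library: build a dense `(sequence_length, target_length)` float mask from `cache_position`, then fold the 2D padding mask into the key dimension. A standalone sketch of that core construction, simplified and assuming a 2D `attention_mask` is always provided:

```python
import torch


def build_causal_mask(attention_mask, sequence_length, target_length, cache_position, dtype=torch.float32):
    """Illustrative mask construction: 0.0 = attend, dtype-min = blocked."""
    min_dtype = torch.finfo(dtype).min
    device = cache_position.device
    # Start fully blocked, then open the causal lower triangle relative to cache_position.
    causal = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal = torch.triu(causal, diagonal=1)
    causal *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal = causal[None, None, :, :].expand(attention_mask.shape[0], 1, -1, -1).clone()
    # Fold the 2D padding mask (1 = real token, 0 = pad) into the key dimension.
    mask_length = attention_mask.shape[-1]
    padding = causal[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(dtype)
    causal[:, :, :, :mask_length] = causal[:, :, :, :mask_length].masked_fill(padding == 0, min_dtype)
    return causal


# Example: batch of 2, prompt length 4, second sequence right-padded by one token.
attn = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
cache_position = torch.arange(4)  # no tokens cached yet
mask = build_causal_mask(attn, sequence_length=4, target_length=4, cache_position=cache_position)
```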
@@ -989,6 +1033,8 @@ def default_init(cls, *args, **kwargs):
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
self.output_layer = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False,
dtype=config.torch_dtype, **init_kwargs)
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embedding.word_embeddings
@@ -1008,6 +1054,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

@@ -1026,6 +1073,8 @@
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)

if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
@@ -1035,12 +1084,14 @@
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)

if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
full_attention_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
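
As a side note on the forward-pass changes above: `cache_position` is just the absolute index of each token in the current call, offset by how many tokens the cache already holds, and `_update_causal_mask` uses it to decide which key positions each query may attend to. A quick sketch with hypothetical sizes:

```python
import torch

# Prefill: nothing cached yet, 5 prompt tokens.
past_seen_tokens = 0
inputs_embeds = torch.randn(1, 5, 32)
cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
)  # tensor([0, 1, 2, 3, 4])

# First decode step: 5 tokens cached, 1 new token.
past_seen_tokens = 5
inputs_embeds = torch.randn(1, 1, 32)
cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
)  # tensor([5])
```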