Add GLM-4 and Later GLM Model (Draft) #31977

Closed
wants to merge 86 commits
Changes from 4 commits
Commits (86)
9cf74d7
add GLM-4
zRzRzRzRzRzRzR Jul 11, 2024
bef7fd9
GLM-4 FastTokenizer
zRzRzRzRzRzRzR Jul 11, 2024
c986fac
tokenizer fix
zRzRzRzRzRzRzR Jul 11, 2024
2da5d32
rename
zRzRzRzRzRzRzR Jul 11, 2024
675e7a1
pad token
zRzRzRzRzRzRzR Jul 11, 2024
304e4ef
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 11, 2024
0b241f2
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 12, 2024
fa44041
Fix past_key_values
duzx16 Jul 14, 2024
24dec6b
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 14, 2024
5d2bf5e
Merge branch 'glm-4' of github.com:zRzRzRzRzRzRzR/transformers into g…
duzx16 Jul 14, 2024
63d49c9
Fix flash attention
duzx16 Jul 14, 2024
0a5adf3
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 15, 2024
51cbf5d
add update
zRzRzRzRzRzRzR Jul 15, 2024
86b5004
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 15, 2024
9a553e5
test with glm
zRzRzRzRzRzRzR Jul 15, 2024
4d45b21
fix test
zRzRzRzRzRzRzR Jul 15, 2024
85cfe41
add discription
zRzRzRzRzRzRzR Jul 15, 2024
860c7ee
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 15, 2024
c83ec2d
update glm
zRzRzRzRzRzRzR Jul 16, 2024
2608010
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 16, 2024
1719000
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 18, 2024
3f0452e
rewrite tokenizer
zRzRzRzRzRzRzR Jul 18, 2024
33d2ca3
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 19, 2024
084988e
fix some test
zRzRzRzRzRzRzR Jul 19, 2024
0cb1531
fix testing
zRzRzRzRzRzRzR Jul 19, 2024
e49718f
Fix RMSNorm initialization
duzx16 Jul 20, 2024
a362206
Fix position ids when passing input_embeds
duzx16 Jul 20, 2024
08b43d9
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 20, 2024
3c5322d
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 23, 2024
dd06993
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 24, 2024
8cc0381
Fix dtype error
duzx16 Jul 24, 2024
a35997e
Merge branch 'glm-4' of github.com:zRzRzRzRzRzRzR/transformers into g…
duzx16 Jul 24, 2024
621d32f
Fix output_layer for classification models
duzx16 Jul 24, 2024
48d1704
fix gradient
zRzRzRzRzRzRzR Jul 24, 2024
5881ed5
remove some skip test
zRzRzRzRzRzRzR Jul 24, 2024
c920ad9
fix small test
zRzRzRzRzRzRzR Jul 24, 2024
21781b3
Fix prepare_inputs_for_generation
duzx16 Jul 24, 2024
9599200
Merge branch 'glm-4' of github.com:zRzRzRzRzRzRzR/transformers into g…
duzx16 Jul 24, 2024
a9b1d0d
fix
zRzRzRzRzRzRzR Jul 25, 2024
0631615
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
9f33751
add converter
zRzRzRzRzRzRzR Jul 25, 2024
2663a13
fix PEP 8
zRzRzRzRzRzRzR Jul 25, 2024
aad19db
remove test
zRzRzRzRzRzRzR Jul 25, 2024
1e9183c
index
zRzRzRzRzRzRzR Jul 25, 2024
e8b90a1
fix doctested
zRzRzRzRzRzRzR Jul 25, 2024
65e1996
remove init
zRzRzRzRzRzRzR Jul 25, 2024
266ce77
fix copied error
zRzRzRzRzRzRzR Jul 25, 2024
cd9c304
fix mlp differ
zRzRzRzRzRzRzR Jul 25, 2024
ba30dad
fix copied eerror
zRzRzRzRzRzRzR Jul 25, 2024
afb1423
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
48aaba1
test_hidden_states_output = False
zRzRzRzRzRzRzR Jul 25, 2024
33d976f
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 25, 2024
0675202
fix
zRzRzRzRzRzRzR Jul 25, 2024
19b0939
Update modeling_glm.py
zRzRzRzRzRzRzR Jul 25, 2024
b2b6c0f
Update __init__.py
zRzRzRzRzRzRzR Jul 25, 2024
6760791
fix glm type error
zRzRzRzRzRzRzR Jul 25, 2024
515d9d9
fix
zRzRzRzRzRzRzR Jul 25, 2024
9951c92
ruff problem
zRzRzRzRzRzRzR Jul 25, 2024
547ac95
Update convert_slow_tokenizer.py
zRzRzRzRzRzRzR Jul 25, 2024
9ba6cf7
Add explanations in English
zRzRzRzRzRzRzR Jul 25, 2024
9fb6405
reformate
zRzRzRzRzRzRzR Jul 25, 2024
e37bb49
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
25aec29
Update configuration_glm.py
zRzRzRzRzRzRzR Jul 25, 2024
58d344a
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 25, 2024
073b811
fix
zRzRzRzRzRzRzR Jul 25, 2024
c0e6ae9
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 25, 2024
6ac085f
fix glm dummy
zRzRzRzRzRzRzR Jul 25, 2024
f140603
Merge branch 'glm-4' of https://github.com/zRzRzRzRzRzRzR/transformer…
zRzRzRzRzRzRzR Jul 25, 2024
65f471d
add doc
zRzRzRzRzRzRzR Jul 26, 2024
7ad819f
fix init
zRzRzRzRzRzRzR Jul 26, 2024
f86af8e
Update __init__.py
zRzRzRzRzRzRzR Jul 26, 2024
c179377
Update dummy_vision_objects.py
zRzRzRzRzRzRzR Jul 26, 2024
41338d7
add_start_docstrings
zRzRzRzRzRzRzR Jul 26, 2024
dba6d1e
fix GLM_START_DOCSTRING
zRzRzRzRzRzRzR Jul 26, 2024
82b0c7f
1
zRzRzRzRzRzRzR Jul 26, 2024
a6b6f4e
Update perf_infer_gpu_one.md
zRzRzRzRzRzRzR Jul 26, 2024
d1a5ee1
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 26, 2024
c99610e
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 27, 2024
b283adc
flash attn
zRzRzRzRzRzRzR Jul 27, 2024
4cc618e
stiil need fix rotary_emb
zRzRzRzRzRzRzR Jul 27, 2024
b476dd0
fix GLMSelfAttension
zRzRzRzRzRzRzR Jul 27, 2024
aab2386
remove _get_unpad_data
zRzRzRzRzRzRzR Jul 27, 2024
550a692
fix GLMSelfAttention
zRzRzRzRzRzRzR Jul 27, 2024
6492ac3
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Jul 30, 2024
c3d4636
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Aug 9, 2024
70b7ff4
Merge branch 'huggingface:main' into glm-4
zRzRzRzRzRzRzR Aug 21, 2024
23 changes: 12 additions & 11 deletions src/transformers/models/glm/modeling_glm.py
@@ -61,7 +61,6 @@ def _config_to_kwargs(args):
return common_kwargs


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
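As an aside on this hunk: the rest of the function body is truncated in this view. For readers who want the shape of the computation, here is a minimal, self-contained sketch of what an unpadding helper like this typically returns for FlashAttention's variable-length kernels, modelled on the Llama helper named in the removed `# Copied from` comment; the function name, example mask, and printed values below are illustrative rather than taken from this PR.

```python
import torch
import torch.nn.functional as F


def _get_unpad_data_sketch(attention_mask: torch.Tensor):
    # Number of real (non-padding) tokens in each sequence of the batch.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the non-padding positions, used to gather/scatter tokens.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading zero: the layout flash-attn's
    # varlen kernels expect for packed (unpadded) batches.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# Example: a batch of two right-padded sequences of lengths 3 and 2.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_len = _get_unpad_data_sketch(mask)
# indices -> tensor([0, 1, 2, 4, 5]); cu_seqlens -> tensor([0, 3, 5]); max_len -> 3
```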
@@ -130,7 +129,6 @@ def forward(self, max_seq_len, offset=0):
)



def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
@@ -269,7 +267,6 @@ def forward(
# adjust key and value for inference
if past_key_value is not None:
key_layer, value_layer = past_key_value.update(key_layer, value_layer, self.layer_number - 1)

if self.multi_query_attention:
key_layer = key_layer.unsqueeze(2)
key_layer = key_layer.expand(
@@ -285,7 +282,6 @@ def forward(
value_layer = value_layer.contiguous().view(
value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
)

# ==================================
# core attention computation
# ==================================
@@ -456,7 +452,6 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
# [b, sq, np, hn] --> [b, sq, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)

return context_layer


@@ -590,13 +585,19 @@ class GLMSdpaAttention(GLMAttention):

def forward(self, query_layer, key_layer, value_layer, attention_mask):
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
is_causal=True,
dropout_p=self.config.attention_dropout if self.training else 0.0)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
is_causal=True,
dropout_p=self.config.attention_dropout if self.training else 0.0)
else:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
attention_mask,
dropout_p=self.config.attention_dropout if self.training else 0.0)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attention_mask,
dropout_p=self.config.attention_dropout if self.training else 0.0)
context_layer = context_layer.transpose(1, 2).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
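As an aside on the reformatted SDPA calls above: the two branches correspond to a fused causal path (no explicit mask, query and key lengths equal) and an explicit-mask path. Below is a minimal standalone sketch of that pattern; the tensor shapes, dropout value, and mask construction are illustrative assumptions, not values read from the GLM config.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: [batch, num_heads, seq_len, head_dim].
batch, num_heads, head_dim = 2, 4, 16
attention_dropout, training = 0.0, False  # assumed; GLM reads dropout from its config

# Path 1: prefill with no explicit mask and matching query/key lengths,
# so SDPA builds the causal mask internally via is_causal=True.
q = torch.randn(batch, num_heads, 8, head_dim)
k = torch.randn(batch, num_heads, 8, head_dim)
v = torch.randn(batch, num_heads, 8, head_dim)
ctx_prefill = F.scaled_dot_product_attention(
    q, k, v, is_causal=True, dropout_p=attention_dropout if training else 0.0
)

# Path 2: decoding against a longer KV cache (or padded inputs), where an
# explicit attention mask is passed instead of is_causal.
q_step = torch.randn(batch, num_heads, 1, head_dim)
k_cache = torch.randn(batch, num_heads, 9, head_dim)
v_cache = torch.randn(batch, num_heads, 9, head_dim)
mask = torch.ones(batch, 1, 1, 9, dtype=torch.bool)  # True = position may be attended
ctx_step = F.scaled_dot_product_attention(
    q_step, k_cache, v_cache, attn_mask=mask, dropout_p=attention_dropout if training else 0.0
)

# Both branches then fold heads back into the hidden dimension,
# [b, heads, sq, hd] -> [b, sq, heads * hd], as the hunk above does.
ctx_prefill = ctx_prefill.transpose(1, 2).contiguous().reshape(batch, 8, num_heads * head_dim)
```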
109 changes: 54 additions & 55 deletions tests/models/glm/test_modeling_glm.py
@@ -17,14 +17,12 @@
import gc
import tempfile
import unittest
from parameterized import parameterized

import pytest

from transformers import AutoTokenizer, GLMConfig, is_torch_available, set_seed
from transformers import AutoTokenizer, GLMConfig, is_torch_available
from transformers.testing_utils import (
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
@@ -60,7 +58,7 @@ def __init__(
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
hidden_size=8,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
@@ -394,61 +392,61 @@ def test_GLM_token_classification_model(self):
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)

@unittest.skip(reason="GLM buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass

@unittest.skip(reason="GLM uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass

@unittest.skip(reason="SQRBound is known to have issues with gc")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

def _check_attentions_for_generate(self, *args, **kwargs):
return True # Model does not return attention

@unittest.skip(reason="Past key values are not returned")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass

@unittest.skip(reason="Past key values are not returned")
def test_model_parallelism(self):
pass
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()

@unittest.skip(reason="Past key values are not returned")
def test_model_parallel_beam_search(self):
pass
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))

hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)

# GLM blocks start with id 1, not 0
self.assertEqual(len(hidden_states), expected_num_layers + 1)

if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
seq_length = seq_length * self.model_tester.chunk_length
else:
seq_length = self.model_tester.seq_length

self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)

if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers + 1)
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[decoder_seq_length, self.model_tester.hidden_size],
)

def _check_past_key_values_for_generate(self, *args, **kwargs):
return True
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self):
pass
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)

@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent GLM")
def test_assisted_decoding_sample(self):
pass
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True

@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
check_hidden_states_output(inputs_dict, config, model_class)

@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass

@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_retain_grad_hidden_states_attentions(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -483,7 +481,6 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)


@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -529,15 +526,13 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)


@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="GLM flash attention does not support right padding")


@slow
@require_torch
class GLMIntegrationTest(unittest.TestCase):
@@ -584,3 +579,7 @@ def test_glm_instruct_generation(self):
"[gMASK] <sop> <|system|> \nYou are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. <|user|> \nTell me the answer of 1 plus 1? <|assistant|> \nThe answer to 1 plus 1 is 2. <|user|>"
]
self.assertListEqual(output_text, EXPECTED_OUTPUT)
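For context, a hedged sketch of how a transcript like `EXPECTED_OUTPUT` above could be reproduced once GLM support is merged; the checkpoint id, `max_new_tokens`, and device handling are assumptions and are not taken from the test itself.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "THUDM/glm-4-9b-chat"  # assumed GLM-4 chat checkpoint, not named in this PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

messages = [
    {"role": "system", "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user."},
    {"role": "user", "content": "Tell me the answer of 1 plus 1?"},
]
# The chat template is what prepends the [gMASK] <sop> tokens and the
# <|system|>/<|user|>/<|assistant|> role tags visible in EXPECTED_OUTPUT.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0]))
```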

@unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
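
To illustrate the skip reason above: under grouped-query attention the cached key/value tensors carry fewer heads than the queries, so the cache layout does not match what the generic `test_past_key_values_format` check expects. A small sketch, with the head counts taken from the tester config in this PR (`num_attention_heads=4`, `num_key_value_heads=2`) and the remaining sizes assumed:

```python
import torch

batch, seq_len, head_dim = 2, 5, 16   # illustrative sizes
num_attention_heads = 4               # query heads (from the tester config above)
num_key_value_heads = 2               # fewer KV heads under grouped-query attention

# Queries use all heads; the cached key/value tensors only carry the KV heads,
# so their head dimension differs from the standard num_attention_heads layout.
query = torch.randn(batch, num_attention_heads, seq_len, head_dim)
cached_key = torch.randn(batch, num_key_value_heads, seq_len, head_dim)
cached_value = torch.randn(batch, num_key_value_heads, seq_len, head_dim)

# At attention time the KV heads are broadcast to match the query heads,
# mirroring the key_layer.unsqueeze(2).expand(...) pattern in the modeling diff.
repeats = num_attention_heads // num_key_value_heads
expanded_key = cached_key.repeat_interleave(repeats, dim=1)
assert expanded_key.shape == query.shape
```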