Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GLM-4 and Later GLM Model (Draft) #31977

Closed
wants to merge 86 commits into from

Conversation

zRzRzRzRzRzRzR
Copy link
Contributor

This is a draft and we will continue work

Who can review?

@ArthurZucker

@xianbaoqian
Copy link

Hi @zRzRzRzRzRzRzR ! Thanks for drafting the PR. The workflow has been failing due to the usage of TikToken. Once the converter scripts converts tiktoken configuration to HF tokenizer configuration, you won't need to import tiktoken during inference at tokenization_glm.py

@xianbaoqian xianbaoqian requested a review from ArthurZucker July 18, 2024 06:05
@ArthurZucker
Copy link
Collaborator

YEP! Will review today 🤗

@zRzRzRzRzRzRzR
Copy link
Contributor Author

Hi @zRzRzRzRzRzRzR ! Thanks for drafting the PR. The workflow has been failing due to the usage of TikToken. Once the converter scripts converts tiktoken configuration to HF tokenizer configuration, you won't need to import tiktoken during inference at tokenization_glm.py

Fix this issue now~ Tks

Fix attention mask for right padding
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am stopping the review as a LOT of the comments are still not adressed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this file should be removed as we can map the GPT2Tokenizer direct and use it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment here, we can use GPT2TokenizerFast!

logger = logging.get_logger(__name__)


class GLMConfig(PretrainedConfig):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is still the issue with the camel casing!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GLM is the name of our model, not Glm. Do we need to stick to camel case in this context as well?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's the same for LLaMa which we set to Llama!

Comment on lines 27 to 28
This is the configuration class to store the configuration of a [`GLMModel`]. It is used to instantiate a Phi-3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This is the configuration class to store the configuration of a [`GLMModel`]. It is used to instantiate a Phi-3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
This is the configuration class to store the configuration of a [`GLMModel`]. It is used to instantiate a GLM
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix

Comment on lines 47 to 51
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from flash_attn import flash_attn_func, flash_attn_varlen_func

_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again this was refactored

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix now

Comment on lines 59 to 68
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix

Comment on lines 89 to 130
class GLMRotaryEmbedding(nn.Module):
def __init__(self, dim, rope_theta=1, original_impl=False, device=None, dtype=None):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
self.register_buffer("inv_freq", inv_freq)
self.dim = dim
self.original_impl = original_impl
self.rope_theta = rope_theta

def forward_impl(
self,
seq_len: int,
n_elem: int,
dtype: torch.dtype,
device: torch.device,
base: int = 10000,
):
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
base = base * self.rope_theta
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta).float()

cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1).to(dtype=dtype)
return cache

def forward(self, max_seq_len, offset=0):
return self.forward_impl(
max_seq_len,
self.dim,
dtype=self.inv_freq.dtype,
device=self.inv_freq.device,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again same comment here, this is equivalent to LlamaRotaryEmbedidng

return tensor_list


class SelfAttention(torch.nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is still waiting!

Comment on lines +180 to +184
if self.multi_query_attention:
self.num_multi_query_groups_per_partition = self.multi_query_group_num
self.qkv_hidden_size = (
self.projection_size + 2 * self.hidden_size_per_attention_head * self.multi_query_group_num
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again same comment about GQA and MQA

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏🏻

logger = logging.get_logger(__name__)


class GLMConfig(PretrainedConfig):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's the same for LLaMa which we set to Llama!

@ArthurZucker
Copy link
Collaborator

Feel free to ping me again for a review!

@Cyrilvallez
Copy link
Member

Cyrilvallez commented Sep 27, 2024

BTW @zRzRzRzRzRzRzR, I took over and am currently adding the model. You can find the new PR here #33823, should be ready pretty soon

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants