Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround for RoPE computed in bf16 for GPT-NeoX #746

Merged
merged 3 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
gaudi_gpt_neox_attention_forward,
gaudi_gpt_neox_layer_forward,
gaudi_gpt_neox_model_forward,
gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache,
gaudi_gptj_block_forward,
gaudi_gptj_model_forward,
gaudi_invert_attention_mask,
Expand Down Expand Up @@ -262,6 +263,9 @@ def adapt_transformers_to_gaudi():
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.forward = gaudi_gpt_neox_model_forward
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward = gaudi_gpt_neox_layer_forward
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = gaudi_gpt_neox_attention_forward
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding._set_cos_sin_cache = (
gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache
)

# Optimization for llama generation on Gaudi
transformers.models.llama.modeling_llama.LlamaForCausalLM = GaudiLlamaForCausalLM
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
gaudi_gpt_neox_attention_forward,
gaudi_gpt_neox_layer_forward,
gaudi_gpt_neox_model_forward,
gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache,
)
from .gptj import (
GaudiGPTJAttention,
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/gpt_neox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
gaudi_gpt_neox_attention_forward,
gaudi_gpt_neox_layer_forward,
gaudi_gpt_neox_model_forward,
gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache,
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def gaudi_gpt_neox_attention_forward(
- add new args token_idx
- optimize KV cache
"""
# Workaround till FusedRoPE is fixed
global FusedRoPE
if self.training and FusedRoPE is not None:
FusedRoPE = None

has_layer_past = layer_past is not None

# Compute QKV
Expand Down Expand Up @@ -404,6 +409,17 @@ def prepare_inputs_for_generation(
return model_inputs


def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()


def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
return FusedRoPE.apply(
Expand Down
6 changes: 3 additions & 3 deletions tests/baselines/gpt_neox_20b.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"deepspeed": {
"learning_rate": 5e-5,
"train_batch_size": 2,
"perplexity": 8.787531864839819,
"train_runtime": 670.5209,
"train_samples_per_second": 8.485,
"perplexity": 8.0545,
"train_runtime": 745.7237,
"train_samples_per_second": 7.242,
"extra_arguments": [
"--dataset_config_name wikitext-2-raw-v1",
"--gradient_checkpointing",
Expand Down
Loading