Skip to content

Commit

Permalink
Llama: fix batched generation (#29109)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Feb 20, 2024
1 parent ff76e7c commit 7d312ad
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
33 changes: 30 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,34 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

@property
def sin_cached(self):
logger.warning_once(
"The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
)
return self._sin_cached

@property
def cos_cached(self):
logger.warning_once(
"The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
)
return self._cos_cached

def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
cos = emb.cos().to(dtype=x.dtype)
sin = emb.sin().to(dtype=x.dtype)
# backwards compatibility
self._cos_cached = cos
self._sin_cached = sin
return cos, sin


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
Expand Down Expand Up @@ -181,6 +204,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
Expand Down Expand Up @@ -1033,6 +1058,7 @@ def _update_causal_mask(self, attention_mask, input_tensor):

batch_size, seq_length = input_tensor.shape[:2]
dtype = input_tensor.dtype
device = input_tensor.device

# support going beyond cached `max_position_embedding`
if seq_length > self.causal_mask.shape[-1]:
Expand All @@ -1048,8 +1074,9 @@ def _update_causal_mask(self, attention_mask, input_tensor):
(self.config.max_position_embeddings, self.config.max_position_embeddings),
fill_value=torch.finfo(dtype).min,
)
causal_mask = torch.triu(mask, diagonal=1).to(dtype)
causal_mask = torch.triu(mask, diagonal=1)

causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is the one that complements the subject you are photograph",
"The best color is the one that complements the skin tone of the",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
]

Expand Down Expand Up @@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is\n\n\n\n\n\n\n\n\n\n",
"We should not undermind the issues at hand, but address them head on.\nI think",
"The best color isЋ the one that complements the skin tone of",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
]

tokenizer = AutoTokenizer.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
"NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
).to("cuda:1")
).to(torch_device)
inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
Expand Down

0 comments on commit 7d312ad

Please sign in to comment.