
Support beam search with reuse_cache and bucket_internal #1472

Merged (10 commits) on Nov 28, 2024
24 changes: 24 additions & 0 deletions examples/text-generation/README.md
@@ -256,6 +256,30 @@ Another way to simulate dynamic input is to use `--simulate_dyn_prompt`.
While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models like Llama. This can be enabled by setting the `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size).


### Using Beam Search

Restriction: Currently, beam search is only enabled for models with model type `llama` or `qwen2` if `reuse_cache` is not enabled. Group beam search and constrained beam search are not supported by optimum-habana yet.
Collaborator
Suggested change
Restriction: Currently beam search is only enabled for the models with model type of `llama` or `qwen2` if `reuse_cache` is not enabled. The group beam search and constrained beam search is not supported by optimum-habana yet.
> Restriction: Currently beam search is only enabled for the models with model type of `llama` or `qwen2` if `reuse_cache` is not enabled. The group beam search and constrained beam search is not supported by optimum-habana yet.

Don't you mean "if `reuse_cache` is enabled"?

Contributor Author

No, but I will clarify the description here. Currently, if `reuse_cache` is not enabled, the KV cache is maintained by generation rather than by the model, so beam search requires the modeling code to provide a `_reorder_cache` function that generation can call. I have only added that function to the LLaMA and Qwen2 modeling code.
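For reference, the `_reorder_cache` added to both models (see the modeling diffs below) boils down to a per-layer `index_select` along the beam dimension of the cached key/value tensors; a minimal sketch:

```python
@staticmethod
def _reorder_cache(past, beam_idx):
    # Re-order each layer's cached (key, value) pair so that its rows follow
    # the beams selected by beam search at this step.
    return tuple(
        (
            layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
            layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
        )
        for layer_past in past
    )
```

Generation re-indexes the cache with `beam_idx` after every step, so each beam keeps seeing its own history.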

Collaborator

Ah okay I see, thanks for clarifying!


Here is an example:
```bash
python run_generation.py \
--model_name_or_path Qwen/Qwen2-7b-Instruct \
--use_hpu_graphs \
--use_kv_cache \
--trim_logits \
--use_flash_attention \
--max_input_tokens 128 \
--max_new_tokens 128 \
--batch_size 4 \
--limit_hpu_graphs \
--reuse_cache \
--bucket_internal \
--bucket_size 128 \
--bf16 \
--num_beams 3
```


### Running with torch.compile

> [!NOTE]
108 changes: 77 additions & 31 deletions optimum/habana/transformers/generation/utils.py
@@ -2736,15 +2736,26 @@ def _beam_search(
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))

# Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
# non eos token per beam.
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
num_selection = max(2, 1 + n_eos_tokens)
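# For example: with a single eos token, num_selection = max(2, 1 + 1) = 2, matching the
# previous hard-coded factor of 2; with three eos tokens it becomes max(2, 4) = 4, and the
# beam trace buffers below grow to num_selection * batch_size * num_beams columns accordingly.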

if self.generation_config.static_shapes:
beam_trace_scores = torch.zeros(
- (input_ids.shape[1], 2 * batch_size * num_beams), device=input_ids.device, dtype=torch.float32
(input_ids.shape[1], num_selection * batch_size * num_beams),
device=input_ids.device,
dtype=torch.float32,
)
beam_trace_indices = torch.zeros(
- (input_ids.shape[1], 2 * batch_size * num_beams), device=input_ids.device, dtype=torch.int64
(input_ids.shape[1], num_selection * batch_size * num_beams),
device=input_ids.device,
dtype=torch.int64,
)
beam_trace_tokens = torch.zeros(
- (input_ids.shape[1], 2 * batch_size * num_beams), device=input_ids.device, dtype=torch.int64
(input_ids.shape[1], num_selection * batch_size * num_beams),
device=input_ids.device,
dtype=torch.int64,
)
beam_trace_idx = torch.tensor(0, device=input_ids.device)
num_eos_tokens = torch.zeros((1), device=input_ids.device, dtype=torch.int64)
@@ -2753,7 +2764,7 @@ def finalize_beams(
def finalize_beams(initial_ids, beam_trace, model_config, length_penalty):
beam_trace_idx, beam_trace_scores, beam_trace_indices, beam_trace_tokens = beam_trace
bs = initial_ids.shape[0]
- num_beams = beam_trace_scores.shape[1] // (2 * bs)
num_beams = beam_trace_scores.shape[1] // (num_selection * bs)

beam_trace_idx = beam_trace_idx.item()
beam_trace_scores = beam_trace_scores[:beam_trace_idx, :]
@@ -2764,11 +2775,12 @@ def finalize_beams(initial_ids, beam_trace, model_config, length_penalty):
root = (float("-inf"), None, None, False)

def resolve_beam(beam):
- if beam == root:
- return []
- score, prev, tok, is_finished = beam
- rest = resolve_beam(prev)
- rest.append(tok)
rest = []
while beam != root:
score, prev, tok, is_finished = beam
rest.append(tok)
beam = prev
rest.reverse()
return rest
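# Note: resolve_beam now follows the prev links back to the root iteratively instead of
# recursing once per generated token, which keeps very long generations clear of Python's
# default recursion limit.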

prev_beams = [[root] * num_beams] * bs
@@ -2782,8 +2794,8 @@ def beam_score(beam):
):
cur_beams = [[] for _ in range(bs)]
for idx, (s, i, t) in enumerate(zip(scores, indices, tokens)):
- batch = idx // (num_beams * 2)
- idx = idx % (num_beams * 2)
batch = idx // (num_beams * num_selection)
idx = idx % (num_beams * num_selection)
b_len = 1 + step
b_score = s.item() / (b_len**length_penalty)
b_tok = t.item()
@@ -2830,12 +2842,17 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
this_peer_finished = False

bucket_size = model_kwargs.get("bucket_size", -1)
prev_idx = -1  # avoid recalculating cache_idx when its value is not changing
bucket_internal = model_kwargs.get("bucket_internal", None)
reduce_recompile = model_kwargs.get("reduce_recompile", False)
prompt_len = input_ids.shape[-1]
- if bucket_size >= 0:
- inc = iter(incrementor(bucket_size, prompt_len))
- if bucket_size > 0:
- assert "position_ids" not in model_kwargs, "Untested path"

if not bucket_internal:
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, cur_len))
if bucket_size > 0 and "position_ids" in model_kwargs:
logger.warning("Untested path for bucketing with position_ids")

if self.generation_config.static_shapes:
initial_ids = input_ids[::num_beams, 0:cur_len]

@@ -2844,7 +2861,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
if lazy_mode:
self.htcore_generation.mark_step()

- if bucket_size > 0:
if bucket_size > 0 and not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
@@ -2879,16 +2896,13 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
)

inputs_per_sub_batches = _split_model_inputs(
- model_inputs,
- split_size=batch_size,
- full_batch_size=batch_beam_size,
- config=self.config.get_text_config(),
model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
)
outputs_per_sub_batch = [
self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
]

- outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())
outputs = stack_model_outputs(outputs_per_sub_batch)
else:
hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
outputs = self(
@@ -2951,10 +2965,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

- # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
- # non eos token per beam.
- n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
- n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
n_tokens_to_keep = num_selection * num_beams
if do_sample:
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
@@ -3038,6 +3049,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

if model_kwargs.get("past_key_values", None) is not None:
if model_kwargs["reuse_cache"]:
model_kwargs["past_key_values"] = unwrap_deepspeed_model(self).reorder_kv_cache(beam_idx)
@@ -3046,17 +3058,21 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
model_kwargs["past_key_values"], beam_idx
)

- # This is needed to properly delete outputs.logits which may be very large for first iteration
- # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
- # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
- # (that way the memory peak does not include outputs.logits)
- del outputs

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

# increase cur_len
cur_len = cur_len + 1
if bucket_size > 0 and bucket_internal:
# Calculate slice idx for kv cache during the decode phase.
# Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size
if prev_idx != idx:
model_kwargs["cache_idx"] = (idx + 1) * bucket_size
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
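# Worked example (hypothetical sizes): with kv_cache_len = 512, bucket_size = 128 and
# token_idx_cpu = 200, idx = 199 // 128 = 1 and cache_idx = (1 + 1) * 128 = 256, so attention
# only reads the first 256 KV-cache slots for this step instead of all 512.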

hb_profer.step()
if self.generation_config.static_shapes:
@@ -3076,13 +3092,43 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
):
this_peer_finished = True

hb_profer.step()
if hb_gen_time is not None:
if not time_to_first_token_done:
time_to_first_token_done = True
import habana_frameworks.torch.hpu as torch_hpu

torch_hpu.synchronize()
hb_gen_time.step()

if (
not model_kwargs.get("pad_done", False)
and not model_kwargs.get("reuse_cache", False)
and bucket_internal
):
# Pad the returned past key values tensors from prefill phase forward run to maximum length
# before starting the decode phase.
if outputs.past_key_values[0][0].shape[2] == model_inputs["input_ids"].shape[1]:
self._pad_past_key_values(model_kwargs)
model_kwargs["pad_done"] = True
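# The padding is performed at most once per generation (guarded by pad_done), right after the
# prefill forward pass, so every subsequent decode step sees past key values already padded to
# the maximum length.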

# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
# (that way the memory peak does not include outputs.logits)
del outputs

if (
model_kwargs.get("use_hpu_graphs", False)
and model_kwargs.get("limit_hpu_graphs", False)
and not model_kwargs.get("reuse_cache", False)
and bucket_internal
):
# Clear HPU graphs input tensors of the decode phase after the full generation while loop
self.clear_inputs()
# Delete past key value tensors
self._remove_past_key_values(model_kwargs)

hb_profer.stop()

if self.generation_config.static_shapes:
19 changes: 19 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -1412,6 +1412,25 @@ def forward(
attentions=outputs.attentions,
)

@staticmethod
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.

Output shares the same memory storage as `past`.
"""
return tuple(
(
layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
)
for layer_past in past
)
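# Note: this path is only used when reuse_cache is off; with reuse_cache enabled, generation
# reorders the model-managed cache through reorder_kv_cache instead (see the _beam_search
# changes above).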

def prepare_inputs_for_generation(
self,
input_ids,
19 changes: 19 additions & 0 deletions optimum/habana/transformers/models/qwen2/modeling_qwen2.py
@@ -870,6 +870,25 @@ def forward(
attentions=outputs.attentions,
)

@staticmethod
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.

Output shares the same memory storage as `past`.
"""
return tuple(
(
layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
)
for layer_past in past
)

def prepare_inputs_for_generation(
self,
input_ids,