enable internal kv bucket in llama #720

Merged · 4 commits · Feb 23, 2024
3 changes: 3 additions & 0 deletions examples/text-generation/README.md
@@ -236,6 +236,9 @@ python run_generation.py \
The `--bucket_size` option is especially useful when processing an input stream of varying lengths, that is, when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. Here `--max_input_tokens -1` specifies no truncation of the input prompts in the dataset.

Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example, `--simulate_dyn_prompt 25 35 45` will extend or crop the default prompt (or the prompt passed in with `--prompt`) to lengths 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, the min and max input lengths from it are also used for warmup. One final optimization that can be used with dynamic inputs is `--reduce_recompile`. The suggested configuration to simulate dynamic inputs after warmup is therefore to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30`

While `--bucket_size` works for any model without changes to the model file, an even more optimized version of bucketing is supported for certain models such as Llama. It can be enabled with the `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size).
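For illustration, an invocation could look roughly like the following (the model name and the surrounding flags are placeholders taken from typical usage of this example script; only `--bucket_size` and `--bucket_internal` are the options discussed here):

```bash
python run_generation.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--use_hpu_graphs \
--use_kv_cache \
--max_new_tokens 2048 \
--bucket_size 128 \
--bucket_internal
```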

### Running with FP8

Llama2-70b and Llama2-7b in FP8 are enabled using the Habana Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
@@ -186,6 +186,11 @@ def setup_parser(parser):
then we use `shape = prompt_length + max_new_tokens`. If a positive number is passed \
we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).",
)
parser.add_argument(
"--bucket_internal",
action="store_true",
help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.",
)
parser.add_argument(
"--dataset_max_samples",
default=-1,
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -332,6 +332,7 @@ def setup_generation_config(args, model, tokenizer):
generation_config.use_cache = args.use_kv_cache
generation_config.static_shapes = is_optimized
generation_config.bucket_size = args.bucket_size if is_optimized else -1
generation_config.bucket_internal = args.bucket_internal
generation_config.do_sample = args.do_sample
generation_config.num_beams = args.num_beams
generation_config.bad_words_ids = bad_words_ids
@@ -44,6 +44,7 @@ def __init__(self, **kwargs):
self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None)
self.reuse_cache = kwargs.get("reuse_cache", None)
self.bucket_size = kwargs.get("bucket_size", -1)
self.bucket_internal = kwargs.get("bucket_internal", None)
self.reduce_recompile = kwargs.get("reduce_recompile", None)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
50 changes: 37 additions & 13 deletions optimum/habana/transformers/generation/utils.py
@@ -588,19 +588,27 @@ def generate(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)

is_greedy_or_beam_and_bucket = generation_config.bucket_size > 0 and (
self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH
or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH
is_greedy_or_beam_and_bucket = (
not generation_config.bucket_internal
and generation_config.bucket_size > 0
and (
self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH
or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH
)
)
model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1
model_kwargs["bucket_internal"] = generation_config.bucket_internal
model_kwargs["reduce_recompile"] = (
generation_config.reduce_recompile if generation_config.reduce_recompile is not None else False
)
if model_kwargs["reduce_recompile"]:
assert generation_config.bucket_size
if generation_config.reuse_cache:
assert self.config.model_type in ["llama"], "reuse_cache only supported by llama at the moment"
assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together"
if not generation_config.bucket_internal:
assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together"
else:
assert generation_config.bucket_size >= 0, "bucket_internal and bucket_size flags set together"
Collaborator:

Here, we are in the case where generation_config.bucket_internal is True, so if this assert fails (i.e. generation_config.bucket_size < 0), it means that bucket_size is not set, right? But the error message says otherwise.

Contributor Author:

@puneeshkhanna I see you've corrected some error messages. I hope this update won't cause a conflict.

Contributor:

@xt574chen - I will update my PR once this gets merged first.


if generation_config.static_shapes:
# Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
@@ -713,6 +721,8 @@ def generate(
token_idx,
generation_config.kv_cache_fp8,
)
model_kwargs["kv_cache_len"] = calculated_max_length

if self.config.model_type in ["llama"]:
if self.config.max_position_embeddings < calculated_max_length:
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)
@@ -1369,12 +1379,15 @@ def greedy_search(
this_peer_finished = False # used by synced_gpus only
bucket_size = model_kwargs.get("bucket_size", -1)
reduce_recompile = model_kwargs.get("reduce_recompile", False)
prev_idx = None  # avoid recalculating cache_idx when its value is not changing
bucket_internal = model_kwargs.get("bucket_internal", None)

prompt_len = input_ids.shape[-1]
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"
if not bucket_internal:
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"

while True:
if lazy_mode:
@@ -1391,11 +1404,22 @@
break

if bucket_size > 0:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
if not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
else:
# Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
if idx != prev_idx:
cache_idx = (idx.item() + 1) * bucket_size
model_kwargs["cache_idx"] = cache_idx
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
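To make the internal-bucketing branch above concrete, here is a minimal standalone sketch of the `cache_idx` arithmetic in the decode loop (the helper name is hypothetical; `token_idx`, `kv_cache_len`, and `bucket_size` mirror the model kwargs in the diff):

```python
# Minimal sketch of the cache_idx bucketing arithmetic (hypothetical helper,
# not part of the PR): the kv cache is attended over in bucket_size chunks
# until the current token passes the last full bucket, then over the full cache.
def compute_cache_idx(token_idx: int, kv_cache_len: int, bucket_size: int) -> int:
    last_full_bucket_end = (kv_cache_len // bucket_size) * bucket_size
    if token_idx <= last_full_bucket_end:
        bucket = (token_idx - 1) // bucket_size  # bucket the current token falls in
        return (bucket + 1) * bucket_size        # slice the cache at that bucket's end
    return kv_cache_len                          # tail region: use the whole cache


# Example: kv_cache_len=300, bucket_size=128
# token_idx 1..128   -> cache_idx 128
# token_idx 129..256 -> cache_idx 256
# token_idx 257..300 -> cache_idx 300
```

The `prev_idx` bookkeeping in the diff only serves to skip this computation while the token index stays inside the same bucket.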
17 changes: 17 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -207,6 +207,7 @@ def pre_attn_forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
@@ -281,6 +282,12 @@
key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)

if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]

if use_cache:
if reuse_cache:
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
@@ -445,6 +452,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -474,6 +482,7 @@
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
**kwargs,
)
self.self_attn.attention_all_reduce(output_pre_attn)
@@ -503,6 +512,7 @@ def pre_attn(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
@@ -517,6 +527,7 @@
reuse_cache,
use_flash_attention,
flash_attention_recompute,
cache_idx=cache_idx,
)
return output_attn, attn_weights, present_key_value

@@ -565,6 +576,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -681,6 +693,7 @@
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)

hidden_states = layer_outputs[0]
@@ -728,6 +741,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)
self.kv_cache_len = max_seq_len

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.model.reorder_kv_cache(beam_idx)
@@ -753,6 +767,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -775,6 +790,7 @@
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
@@ -886,6 +902,7 @@ def prepare_inputs_for_generation(
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"cache_idx": kwargs.get("cache_idx"),
}
)
return model_inputs
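As a rough illustration of what the `cache_idx` slicing in `pre_attn_forward` does during single-token decode (`q_len == 1`), here is a self-contained tensor sketch with made-up shapes (not the module's actual code):

```python
import torch

# Made-up decode-time shapes: batch=2, heads=4, kv cache length=256, head_dim=64
key_states = torch.randn(2, 4, 256, 64)
value_states = torch.randn(2, 4, 256, 64)
attention_mask = torch.zeros(2, 1, 1, 256)

cache_idx = 128  # end of the bucket the current token falls into

# Only the first cache_idx cached positions take part in attention this step,
# which shrinks the matmuls inside the attention block.
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]

print(key_states.shape)      # torch.Size([2, 4, 128, 64])
print(attention_mask.shape)  # torch.Size([2, 1, 1, 128])
```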