Commit
Enable internal kv bucket in llama (huggingface#720)
xt574chen authored and Jinyan chen committed Feb 27, 2024
1 parent a8a1efd commit 82b6478
Showing 5 changed files with 29 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/text-generation/README.md
@@ -236,6 +236,9 @@ python run_generation.py \
The `--bucket_size` option is especially useful when processing an input stream with varying lengths, that is, when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. `--max_input_tokens -1` specifies no truncation of input prompts in the dataset.

Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example, `--simulate_dyn_prompt 25,35,45` will extend or crop the default prompt (or the prompt passed in using `--prompt`) to sizes 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, its min and max input lengths are also used for warmup. One final optimization that can be used in the case of dynamic inputs is `--reduce_recompile`. Thus the suggested configuration to simulate dynamic inputs after warmup is to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30`.

While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models, such as Llama. It can be enabled by setting the `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size).
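
For intuition only (this snippet is not part of the commit), here is a minimal sketch of how a prompt length maps to a bucketed shape, following the `shape = bucket_size * ceil(prompt_len / bucket_size)` rounding documented for `bucket_size` further down in this diff:

```python
import math

def bucketed_length(prompt_len: int, bucket_size: int) -> int:
    # Pad the prompt length up to the next multiple of bucket_size
    # (shape = bucket_size * ceil(prompt_len / bucket_size)).
    return bucket_size * math.ceil(prompt_len / bucket_size)

# With the suggested `--bucket_size 30`, the simulated prompt lengths
# 25, 35 and 45 pad to 30, 60 and 60 respectively.
print([bucketed_length(n, 30) for n in (25, 35, 45)])  # [30, 60, 60]
```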

### Running with FP8

Llama2-70b and Llama2-7b in FP8 are enabled using the Habana Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
@@ -186,6 +186,11 @@ def setup_parser(parser):
then we use `shape = prompt_length + max_new_tokens`. If a positive number is passed \
we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).",
)
parser.add_argument(
"--bucket_internal",
action="store_true",
help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.",
)
parser.add_argument(
"--dataset_max_samples",
default=-1,
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -336,6 +336,7 @@ def setup_generation_config(args, model, tokenizer):
generation_config.use_cache = args.use_kv_cache
generation_config.static_shapes = is_optimized
generation_config.bucket_size = args.bucket_size if is_optimized else -1
generation_config.bucket_internal = args.bucket_internal
generation_config.do_sample = args.do_sample
generation_config.num_beams = args.num_beams
generation_config.bad_words_ids = bad_words_ids
3 changes: 3 additions & 0 deletions optimum/habana/transformers/generation/configuration_utils.py
@@ -27,6 +27,8 @@ class GaudiGenerationConfig(GenerationConfig):
If negative (default=-1) pad to max if `static_shapes` is set. Else start with
`shape = bucket_size * ceil(prompt_len/bucket_size)` and then grow space by `bucket_size` when needed.
Only active if `static_shapes` is used. Can't be used with `reuse_cache`.
bucket_internal (`bool`, *optional*):
Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.
kv_cache_fp8 (`bool`, *optional*):
Store kv-cache in float8 when kv-cache is used
use_flash_attention (`bool`, *optional*):
@@ -44,6 +46,7 @@ def __init__(self, **kwargs):
self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None)
self.reuse_cache = kwargs.get("reuse_cache", None)
self.bucket_size = kwargs.get("bucket_size", -1)
self.bucket_internal = kwargs.get("bucket_internal", None)
self.reduce_recompile = kwargs.get("reduce_recompile", None)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
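As a usage sketch only, the new field can be set alongside the existing bucketing option. This assumes `GaudiGenerationConfig` is re-exported from `optimum.habana.transformers.generation` (otherwise import it from `configuration_utils` directly); the other arguments are standard generation options chosen purely for illustration:

```python
from optimum.habana.transformers.generation import GaudiGenerationConfig

# Illustrative configuration: enable decode-phase kv-cache bucketing together
# with a bucket size, mirroring the --bucket_size / --bucket_internal flags
# handled by setup_generation_config() in examples/text-generation/utils.py.
generation_config = GaudiGenerationConfig(
    max_new_tokens=1024,
    use_cache=True,
    bucket_size=128,       # grow static shapes in steps of 128 tokens
    bucket_internal=True,  # slice the kv cache to the active bucket while decoding
)
print(generation_config.bucket_size, generation_config.bucket_internal)
```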
17 changes: 17 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -207,6 +207,7 @@ def pre_attn_forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
@@ -281,6 +282,12 @@ def pre_attn_forward(
key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)

if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]

if use_cache:
if reuse_cache:
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
@@ -445,6 +452,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -474,6 +482,7 @@ def forward(
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
**kwargs,
)
self.self_attn.attention_all_reduce(output_pre_attn)
@@ -503,6 +512,7 @@ def pre_attn(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
@@ -517,6 +527,7 @@ def pre_attn(
reuse_cache,
use_flash_attention,
flash_attention_recompute,
cache_idx=cache_idx,
)
return output_attn, attn_weights, present_key_value

@@ -565,6 +576,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -681,6 +693,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)

hidden_states = layer_outputs[0]
@@ -728,6 +741,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)
self.kv_cache_len = max_seq_len

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.model.reorder_kv_cache(beam_idx)
@@ -753,6 +767,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -775,6 +790,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
@@ -886,6 +902,7 @@ def prepare_inputs_for_generation(
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"cache_idx": kwargs.get("cache_idx"),
}
)
return model_inputs
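To make the decode-phase behavior of the new `cache_idx` argument concrete, here is a self-contained sketch that restates the slicing added to `pre_attn_forward` with toy tensors (tensor names and shapes are illustrative, not the module's real ones):

```python
import torch

def slice_kv_to_bucket(key_states, value_states, attention_mask, cache_idx, q_len):
    # During single-token decode (q_len == 1) the kv cache is preallocated to
    # the full max length, but only the first `cache_idx` positions belong to
    # the current bucket, so keys, values and the mask are sliced down to it.
    if cache_idx is not None and q_len == 1:
        key_states = key_states[:, :, :cache_idx, :]
        value_states = value_states[:, :, :cache_idx, :]
        attention_mask = attention_mask[:, :, :, :cache_idx]
    return key_states, value_states, attention_mask

# Toy shapes: batch=1, heads=4, preallocated kv length=512, head_dim=64.
k = torch.zeros(1, 4, 512, 64)
v = torch.zeros(1, 4, 512, 64)
mask = torch.zeros(1, 1, 1, 512)
k, v, mask = slice_kv_to_bucket(k, v, mask, cache_idx=128, q_len=1)
print(k.shape, v.shape, mask.shape)  # -> (1, 4, 128, 64), (1, 4, 128, 64), (1, 1, 1, 128)
```

As the diff shows, `cache_idx` reaches the attention layers through `prepare_inputs_for_generation`, which forwards `kwargs.get("cache_idx")` into the model call.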
