Toggling KV-caches (#1763)
SalmanMohammadi authored Oct 20, 2024
1 parent 3ca0d30 commit 73aa126
Showing 14 changed files with 569 additions and 109 deletions.
3 changes: 3 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -103,6 +103,9 @@ These are utilities that are common to and can be used by all modules.
:nosignatures:

common_utils.reparametrize_as_dtype_state_dict_post_hook
common_utils.local_kv_cache
common_utils.disable_kv_cache
common_utils.delete_kv_caches


Vision Transforms
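The three utilities added to the API reference above are the new KV-cache toggling helpers introduced by this commit. As a minimal sketch of how they might be used together: the keyword arguments to local_kv_cache follow the recipe changes further down in this diff, while the exact usage of disable_kv_cache and delete_kv_caches is inferred from their names and the docs entry, so treat it as an assumption rather than a documented contract.

import torch

from torchtune.modules.common_utils import (
    delete_kv_caches,
    disable_kv_cache,
    local_kv_cache,
)

# `model` is assumed to be a torchtune TransformerDecoder; `tokens` and
# `input_pos` stand in for real inputs to a cached decode step.

# Allocate KV-caches only for the duration of this block; they are removed on exit.
with local_kv_cache(
    model,
    batch_size=1,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    decoder_max_seq_len=2048,
):
    logits = model(tokens, input_pos=input_pos)

# Keep any allocated caches, but skip reading or updating them inside the block.
with disable_kv_cache(model):
    logits = model(tokens)

# Explicitly free caches that were previously set up on the model.
delete_kv_caches(model)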
108 changes: 49 additions & 59 deletions recipes/eleuther_eval.py
@@ -13,7 +13,7 @@

import torch

from lm_eval.evaluator import evaluate, get_task_list
from lm_eval.evaluator import evaluate
from lm_eval.models.hf_vlms import HFMultimodalLM
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import get_task_dict, TaskManager
Expand All @@ -29,6 +29,7 @@
)
from torchtune.generation import generate, sample
from torchtune.modules import TransformerDecoder
from torchtune.modules.common_utils import local_kv_cache
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
@@ -224,18 +225,11 @@ def _model_multimodal_generate(
"multimodal generation."
)

# 2. Setup KV cache and masks for bsz 1
encoder_max_seq_len = (
self.model_transform.image_seq_len * self._max_images_per_sample
)
# Setup masks for bsz 1
with self.device:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
encoder_max_seq_len=self.model_transform.image_seq_len
* self._max_images_per_sample,
decoder_max_seq_len=self.max_length,
)
causal_mask = torch.tril(
torch.ones(
size=(self.max_length, self.max_length),
@@ -247,28 +241,37 @@
batch["input_pos"] = input_pos[None, :seq_len]
batch["mask"] = causal_mask[None, :seq_len]

# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
# 2. Setup KV cache
with local_kv_cache(
self.model,
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
encoder_max_seq_len=encoder_max_seq_len,
decoder_max_seq_len=self.max_length,
):
# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

# 5. Return generated tokens
return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
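Stripped of the recipe's bookkeeping, the rewritten method above follows a simple pattern: allocate caches only for the duration of generation, prefill once with the full prompt, then decode one token at a time while the cache tracks position. A condensed sketch of that pattern, reusing the recipe's imports of local_kv_cache and sample; placeholder names such as stop_tokens and max_new_tokens are assumptions, and the recipe's explicit mask handling is omitted for brevity:

with local_kv_cache(
    model,
    batch_size=1,
    device=device,
    dtype=dtype,
    decoder_max_seq_len=max_length,
):
    # Prefill: run the whole prompt once so the cache holds its keys/values.
    logits = model(prompt, input_pos=input_pos[None, :seq_len])[:, -1]
    token = sample(logits, temperature=0.0, top_k=None)
    generated = [token.item()]
    seq_len += 1

    # Decode: feed back one token per step, advancing the cache position.
    while token.item() not in stop_tokens and len(generated) < max_new_tokens:
        logits = model(token, input_pos=input_pos[None, seq_len])[:, -1]
        token = sample(logits, temperature=0.0, top_k=None)
        generated.append(token.item())
        seq_len += 1
# On exit the caches are deleted, so the next call starts from a clean model.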
@@ -388,18 +391,6 @@ def _model_generate(
"Any decoding strategy other than greedy is not supported."
)

# Setup KV caches OR reset them if they're already set up
if self.enable_kv_cache:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
with self.device:
self.model.setup_caches(
batch_size=self.batch_size,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
)

# if we've received fewer than self._batch_size samples in the current
# batch we need to pad the batch out. here we're padding the end of the
# current batch to the correct length. this is because when we use static
@@ -409,15 +400,21 @@
(0, 0, 0, self._batch_size - bsz),
value=self._tokenizer.eos_id, # pad with one of the tokenizer's stop tokens so generation can stop early
)

toks, _ = generate(
with local_kv_cache(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
):
toks, _ = generate(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
return toks[:bsz]


@@ -536,13 +533,6 @@ def evaluate(self) -> None:
# Initialize tasks for the harness
task_manager = TaskManager(include_path=self.include_path)
task_dict = get_task_dict(self.tasks, task_manager)
task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)])
if len(task_types) > 1 and "generate_until" in task_types:
raise RuntimeError(
"Evaluating on multiple task types where any one task involves "
"generation is currently not supported. See the issue below for more info: "
"https://github.com/pytorch/torchtune/issues/1621"
)

# Run evaluation
t0 = time.time()
9 changes: 0 additions & 9 deletions tests/recipes/test_eleuther_eval.py
@@ -194,12 +194,3 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
match="QAT quantizers should only be used during quantization aware training",
):
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
self, caplog, capsys, monkeypatch, tmpdir
):
# We can't currently specify both generate_until and mc_tasks in the same run
# b/c the KV cache won't be reset and the result will be different. This test
# catches that error
pass
20 changes: 12 additions & 8 deletions tests/torchtune/modules/model_fusion/test_fusion_layer.py
@@ -25,10 +25,13 @@ def __init__(self, dim):
self.cache_enabled = False
self.encoder_max_seq_len = None

def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.cache_enabled = True
self.encoder_max_seq_len = encoder_max_seq_len

def caches_are_enabled(self):
return self.cache_enabled

def reset_cache(self):
self.cache_enabled = False

@@ -43,10 +46,13 @@ def __init__(self, dim):
self.cache_enabled = False
self.decoder_max_seq_len = None

def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.cache_enabled = True
self.decoder_max_seq_len = decoder_max_seq_len

def caches_are_enabled(self):
return self.cache_enabled

def reset_cache(self):
self.cache_enabled = False

@@ -131,22 +137,20 @@ def test_fusion_params(self, fused_layer):
"fusion_layer.linear.bias",
}

def test_setup_cache(self, fused_layer):
def test_setup_caches(self, fused_layer):
"""
Test that the cache methods work as expected.
"""
fused_layer.setup_cache(
fused_layer.setup_caches(
2, torch.float32, encoder_max_seq_len=10, decoder_max_seq_len=10
)
assert fused_layer.cache_enabled
fused_layer.reset_cache()
assert not fused_layer.cache_enabled
assert fused_layer.caches_are_enabled()

def test_setup_cache_different_cache_seq_len(self, fused_layer):
"""
Test that the cache methods work as expected.
"""
fused_layer.setup_cache(
fused_layer.setup_caches(
2, torch.float32, encoder_max_seq_len=5, decoder_max_seq_len=10
)

6 changes: 3 additions & 3 deletions tests/torchtune/modules/model_fusion/test_fusion_models.py
@@ -32,7 +32,7 @@ def __init__(self, dim, vocab_size):
def setup_caches(self, batch_size, dtype, *args, **kwargs):
self.cache_enabled = True

def caches_are_enabled(self):
def caches_are_setup(self):
return self.cache_enabled

def reset_caches(self):
@@ -144,9 +144,9 @@ def test_setup_cache(self, fused_model):
Test that the cache methods work as expected.
"""
fused_model.setup_caches(2, torch.float32)
assert fused_model.caches_are_enabled()
assert fused_model.caches_are_setup()
fused_model.reset_caches()
assert not fused_model.caches_are_enabled()
assert not fused_model.caches_are_setup()

def test_set_trainable_params(self, fused_model, encoder, decoder):
"""
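The rename in the test above (caches_are_enabled to caches_are_setup) suggests the fused model now distinguishes caches that have been allocated from caches that are actively in use. A hedged sketch of how calling code might rely on that split, assuming torch is imported as above and that both methods exist on the model after this change:

# Assumption: the model exposes both caches_are_setup() and caches_are_enabled().
if not model.caches_are_setup():
    model.setup_caches(batch_size=1, dtype=torch.bfloat16, decoder_max_seq_len=2048)

if model.caches_are_enabled():
    # Incremental decode path: caches are allocated and switched on.
    logits = model(next_token, input_pos=input_pos)
else:
    # Caches exist but are currently disabled (e.g. inside disable_kv_cache),
    # so run a normal full forward pass instead.
    logits = model(tokens)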
3 changes: 3 additions & 0 deletions tests/torchtune/modules/test_attention.py
@@ -141,6 +141,7 @@ def gqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -195,6 +196,7 @@ def mha_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -249,6 +251,7 @@ def mqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
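The attention fixtures above now set cache_enabled = True explicitly, which fits the theme of this commit: a layer can hold a KV-cache without using it. A hedged sketch of what that separation enables, assuming disable_kv_cache toggles exactly this kind of per-layer flag and that the cache utilities are imported as in the recipe:

with local_kv_cache(
    model, batch_size=1, device=device, dtype=dtype, decoder_max_seq_len=2048
):
    cached_logits = model(tokens, input_pos=input_pos)  # caches are read and updated

    with disable_kv_cache(model):
        # Caches stay allocated, but this forward pass neither reads nor writes them.
        full_logits = model(other_tokens)

    # Outside the inner block, cached decoding can resume where it left off.
    next_logits = model(next_token, input_pos=next_input_pos)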
(The remaining 8 changed files are not shown.)
