Toggling KV-caches #1763

Merged · 22 commits · Oct 20, 2024
Changes from 9 commits
3 changes: 3 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -104,6 +104,9 @@ These are utilities that are common to and can be used by all modules.
:nosignatures:

common_utils.reparametrize_as_dtype_state_dict_post_hook
common_utils.setup_use_local_kv_cache
common_utils.use_persistent_kv_cache
common_utils.delete_kv_caches


Vision Transforms
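The three utilities documented above are the public surface of this PR. As a rough usage sketch (not taken from the diff), `setup_use_local_kv_cache` wraps a generation call so that KV caches exist only inside the `with` block. The keyword arguments below mirror the call sites in `recipes/eleuther_eval.py`; the bare `delete_kv_caches(model)` call at the end is an assumed signature based only on the documented name.

```python
# Hypothetical usage of the KV-cache utilities added in this PR (a sketch, not the recipe code).
import torch
from torchtune.generation import generate
from torchtune.modules.common_utils import delete_kv_caches, setup_use_local_kv_cache


def generate_with_local_cache(model, prompt, *, device, dtype, max_seq_len, max_new_tokens):
    # Caches are allocated on entry and removed on exit, so the model is left
    # cache-free for whatever evaluation runs next.
    with setup_use_local_kv_cache(
        model,
        batch_size=prompt.shape[0],
        device=device,
        dtype=dtype,
        decoder_max_seq_len=max_seq_len,
    ):
        tokens, _ = generate(
            model,
            prompt,
            max_generated_tokens=max_new_tokens,
            temperature=0.0,  # greedy decoding, as in the eval recipe
            top_k=None,
        )
    return tokens


# If caches were attached elsewhere and need to be dropped explicitly (assumed signature):
# delete_kv_caches(model)
```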
1 change: 1 addition & 0 deletions recipes/configs/llama3_2_vision/evaluation.yaml
@@ -41,6 +41,7 @@ tasks: ["mmmu_val_science"] # Defaulting to science as a good subset
limit: null
batch_size: 1
enable_kv_cache: True
max_seq_length: 8192

# Quantization specific args
# Quantization is not supported in this specific config
108 changes: 49 additions & 59 deletions recipes/eleuther_eval.py
@@ -24,6 +24,7 @@
)
from torchtune.generation import generate, sample
from torchtune.modules import TransformerDecoder
from torchtune.modules.common_utils import setup_use_local_kv_cache
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
@@ -47,7 +48,7 @@
)
sys.exit(1)

from lm_eval.evaluator import evaluate, get_task_list
from lm_eval.evaluator import evaluate

# User doesn't have to have nightlies installed, they just won't be able
# to use the multimodal model
@@ -253,18 +254,11 @@ def _model_multimodal_generate(
"multimodal generation."
)

# 2. Setup KV cache and masks for bsz 1
encoder_max_seq_len = (
self.model_transform.image_seq_len * self._max_images_per_sample
)
# Setup masks for bsz 1
with self.device:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
encoder_max_seq_len=self.model_transform.image_seq_len
* self._max_images_per_sample,
decoder_max_seq_len=self.max_length,
)
causal_mask = torch.tril(
torch.ones(
size=(self.max_length, self.max_length),
@@ -276,28 +270,37 @@
batch["input_pos"] = input_pos[None, :seq_len]
batch["mask"] = causal_mask[None, :seq_len]

# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
# 2. Setup KV cache
with setup_use_local_kv_cache(
self.model,
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
encoder_max_seq_len=encoder_max_seq_len,
decoder_max_seq_len=self.max_length,
):
# 3. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

cache_mask = batch["encoder_mask"][:, -1:]

# 4. Continue generating
for _ in range(max_length):
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(
token,
mask=causal_mask[None, seq_len, None, :],
encoder_input=None,
encoder_mask=cache_mask,
input_pos=input_pos[None, seq_len],
)[:, -1]
token = sample(logits, temperature=0.0, top_k=None)
generated_tokens.append(token.item())
seq_len += 1

# 5. Return generated tokens
return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
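For context, here is a minimal sketch of what a "local" KV-cache context manager such as `setup_use_local_kv_cache` plausibly does: set caches up on entry and tear them down on exit, which is why the manual "reset if already set up, otherwise set up" branch above could be deleted. This is an illustrative assumption, not the torchtune implementation; the `setup_caches` arguments mirror the removed code, and the teardown loop stands in for what `delete_kv_caches` provides.

```python
# Illustrative sketch (assumption) of a local KV-cache context manager.
# The real logic lives in torchtune/modules/common_utils.py and may differ.
from contextlib import contextmanager
from typing import Optional

import torch


@contextmanager
def local_kv_cache_sketch(
    model: torch.nn.Module,
    *,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
    encoder_max_seq_len: Optional[int] = None,
    decoder_max_seq_len: Optional[int] = None,
):
    """Set up KV caches on entry and remove them on exit (illustrative only)."""
    # Allocate caches on the target device, as the removed `with self.device:` block did.
    with device:
        model.setup_caches(
            batch_size=batch_size,
            dtype=dtype,
            encoder_max_seq_len=encoder_max_seq_len,
            decoder_max_seq_len=decoder_max_seq_len,
        )
    try:
        yield model
    finally:
        # Teardown stands in for what delete_kv_caches provides in torchtune.
        for module in model.modules():
            if hasattr(module, "kv_cache"):
                module.kv_cache = None
                module.cache_enabled = False
```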
@@ -417,18 +420,6 @@ def _model_generate(
"Any decoding strategy other than greedy is not supported."
)

# Setup KV caches OR reset them if they're already set up
if self.enable_kv_cache:
if self.model.caches_are_enabled():
self.model.reset_caches()
else:
with self.device:
self.model.setup_caches(
batch_size=self.batch_size,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
)

# if we've received fewer than self._batch_size samples in the current
# batch we need to pad the batch out. here we're padding the end of the
# current batch to the correct length. this is because when we use static
@@ -438,15 +429,21 @@
(0, 0, 0, self._batch_size - bsz),
value=self._tokenizer.eos_id, # pad with one of the tokenizer's stop tokens so generation can stop early
)

toks, _ = generate(
with setup_use_local_kv_cache(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
batch_size=self.batch_size,
device=self.device,
dtype=self._dtype,
decoder_max_seq_len=self.max_length,
):
toks, _ = generate(
self.model,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None,
stop_tokens=self._tokenizer.stop_tokens,
)
return toks[:bsz]
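The padding above matters because static KV caches are allocated for a fixed batch size: a short final batch is padded out with an EOS row per missing sample, and the extra generations are discarded afterwards. A small, self-contained illustration of that pattern (tensor sizes and `eos_id` are made up):

```python
# Illustration of padding a short final batch up to a static batch size (values assumed).
import torch
import torch.nn.functional as F

static_batch_size = 4
eos_id = 2
context = torch.randint(0, 32_000, (3, 16))  # only 3 samples arrived in the last batch

bsz = context.shape[0]
maybe_padded_context = F.pad(
    context,
    (0, 0, 0, static_batch_size - bsz),  # (last dim: 0, 0) then (first dim: 0, missing rows)
    value=eos_id,  # pad with a stop token so generation for the dummy rows can end early
)
assert maybe_padded_context.shape == (static_batch_size, 16)

# After generation, only the real samples are kept, mirroring `return toks[:bsz]` above.
```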


@@ -555,13 +552,6 @@ def evaluate(self) -> None:
# Initialize tasks for the harness
task_manager = TaskManager(include_path=self.include_path)
task_dict = get_task_dict(self.tasks, task_manager)
task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)])
if len(task_types) > 1 and "generate_until" in task_types:
raise RuntimeError(
"Evaluating on multiple task types where any one task involves "
"generation is currently not supported. See the issue below for more info: "
"https://github.com/pytorch/torchtune/issues/1621"
)

# Run evaluation
t0 = time.time()
9 changes: 0 additions & 9 deletions tests/recipes/test_eleuther_eval.py
@@ -199,12 +199,3 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
match="QAT quantizers should only be used during quantization aware training",
):
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
Contributor: Should we instead have a test that this now actually works?

self, caplog, capsys, monkeypatch, tmpdir
):
# We can't currently specify both generate_until and mc_tasks in the same run
# b/c the KV cache won't be reset and the result will be different. This test
# catches that error
pass
6 changes: 3 additions & 3 deletions tests/torchtune/modules/model_fusion/test_fusion_models.py
@@ -32,7 +32,7 @@ def __init__(self, dim, vocab_size):
def setup_caches(self, batch_size, dtype, *args, **kwargs):
self.cache_enabled = True

def caches_are_enabled(self):
def caches_are_setup(self):
return self.cache_enabled

def reset_caches(self):
@@ -144,9 +144,9 @@ def test_setup_cache(self, fused_model):
Test that the cache methods work as expected.
"""
fused_model.setup_caches(2, torch.float32)
assert fused_model.caches_are_enabled()
assert fused_model.caches_are_setup()
fused_model.reset_caches()
assert not fused_model.caches_are_enabled()
assert not fused_model.caches_are_setup()

def test_set_trainable_params(self, fused_model, encoder, decoder):
"""
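The rename from `caches_are_enabled` to `caches_are_setup` reflects the split this PR introduces between caches that have been allocated and caches that are actively written to during forward passes (the new per-layer `cache_enabled` flag). A toy mock, loosely modelled on the test double above, sketches how the two notions fit together; everything here other than the names taken from the diff is illustrative.

```python
# Toy sketch (not torchtune code) of the "set up" vs. "enabled" distinction.
class ToyAttention:
    def __init__(self):
        self.kv_cache = None
        self.cache_enabled = False  # whether forward passes write to the cache

    def setup_cache(self):
        self.kv_cache = {}          # stand-in for a real KVCache
        self.cache_enabled = True


class ToyDecoder:
    def __init__(self, num_layers: int = 2):
        self.layers = [ToyAttention() for _ in range(num_layers)]

    def caches_are_setup(self) -> bool:
        # Mirrors the Gemma transformer change in this diff, which checks the first layer.
        return self.layers[0].cache_enabled


decoder = ToyDecoder()
assert not decoder.caches_are_setup()
for layer in decoder.layers:
    layer.setup_cache()
assert decoder.caches_are_setup()
```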
3 changes: 3 additions & 0 deletions tests/torchtune/modules/test_attention.py
@@ -141,6 +141,7 @@ def gqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -195,6 +196,7 @@ def mha_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
@@ -249,6 +251,7 @@ def mqa_kv_cache(
kv_cache=kv_cache,
max_seq_len=max_seq_len,
)
attn.cache_enabled = True
fixed_init_model(attn)
attn.eval()
return attn
4 changes: 2 additions & 2 deletions torchtune/generation/_generation.py
@@ -252,7 +252,7 @@ def generate(
total_response_length = prompt_length + max_generated_tokens

generated_tokens = prompt.clone()
incremental_decoding = model.caches_are_enabled()
incremental_decoding = model.caches_are_setup()

# grab the correct max_seq_len to generate full causal masks/position ids
# this is the model's max cache len if incremental decoding, or the sequence
@@ -366,7 +366,7 @@ def generate(
tokens, logits = custom_generate_next_token(
model,
input_pos=curr_input_pos,
x=tokens,
x=tokens.clone(),
mask=curr_masks,
temperature=temperature,
top_k=top_k,
2 changes: 1 addition & 1 deletion torchtune/models/clip/_position_embeddings.py
@@ -570,7 +570,7 @@ def _load_state_dict_hook(
if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb:
raise ValueError(
"Expected embedding shape to be (..., num_tokens, tgt_emb) to match"
f" but found shapes {self.embedding.shape} and {state_dict[prefix+'embedding'].shape}"
f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}"
)

if inpt_max_num_tiles_x != inpt_max_num_tiles_y:
2 changes: 1 addition & 1 deletion torchtune/models/gemma/transformer.py
@@ -70,7 +70,7 @@ def __init__(
self.norm_embeddings = norm_embeddings
self.num_output_chunks = 0

def caches_are_enabled(self) -> bool:
def caches_are_setup(self) -> bool:
"""Check if the key value caches are setup."""
return self.layers[0].cache_enabled

10 changes: 9 additions & 1 deletion torchtune/modules/__init__.py
@@ -6,7 +6,12 @@

from .attention import MultiHeadAttention # noqa
from .attention_utils import create_block_causal_mask, packed_block_causal_mask
from .common_utils import reparametrize_as_dtype_state_dict_post_hook
from .common_utils import (
delete_kv_caches,
reparametrize_as_dtype_state_dict_post_hook,
setup_use_local_kv_cache,
use_persistent_kv_cache,
)
from .feed_forward import FeedForward # noqa
from .kv_cache import KVCache # noqa
from .layer_norm import Fp32LayerNorm # noqa
@@ -43,4 +48,7 @@
"reparametrize_as_dtype_state_dict_post_hook",
"create_block_causal_mask",
"packed_block_causal_mask",
"setup_use_local_kv_cache",
"delete_kv_caches",
"use_persistent_kv_cache",
]
8 changes: 7 additions & 1 deletion torchtune/modules/attention.py
@@ -140,6 +140,11 @@ def __init__(
# Use flex attention if supported and we are sample packing
self._attention_call = _sdpa_or_flex_attention()

# this flag indicates whether to update the kv-cache during forward
# passes. when disabled, we can have the cache setup but still
# perform normal forward passes
self.cache_enabled = False

def setup_cache(
self, batch_size: int, dtype: torch.dtype, max_seq_len: int
) -> None:
@@ -164,6 +169,7 @@ def setup_cache(
head_dim=self.head_dim,
dtype=dtype,
)
self.cache_enabled = True

def reset_cache(self):
"""Reset the key value caches."""
@@ -291,7 +297,7 @@ def forward(
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None:
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

output = self._attention_call(
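The guard added to `forward` above is what makes toggling possible: a layer can hold an allocated `kv_cache` yet skip cache writes while `cache_enabled` is `False`. A minimal stand-alone sketch of that gating pattern follows; the class and cache here are illustrative, not torchtune's `MultiHeadAttention` or `KVCache`.

```python
# Minimal sketch of the cache-update gate; names besides kv_cache/cache_enabled are made up.
import torch
import torch.nn as nn


class GatedCacheAttention(nn.Module):
    def __init__(self, kv_cache=None):
        super().__init__()
        self.kv_cache = kv_cache
        self.cache_enabled = False  # set to True by setup_cache in this PR

    def forward(self, k: torch.Tensor, v: torch.Tensor):
        # Update the key-value cache only when a cache exists AND updates are enabled,
        # mirroring `if self.kv_cache is not None and self.cache_enabled:` above.
        if self.kv_cache is not None and self.cache_enabled:
            k, v = self.kv_cache.update(k, v)
        return k, v


# With cache_enabled left False, forward behaves like an ordinary, cache-free pass:
attn = GatedCacheAttention(kv_cache=None)
k = v = torch.randn(1, 4, 8)
k_out, v_out = attn(k, v)
assert torch.equal(k_out, k)
```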