
Toggling KV-caches #1763

Merged: 22 commits merged into pytorch:main on Oct 20, 2024

Conversation

@SalmanMohammadi (Collaborator) commented Oct 7, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Closes #1621
RFC: #1675


Multimodal eval results

On main

root@736fb59b1bb9:~/torchtune#  tune run eleuther_eval --config llama3_2_vision/evaluation limit=5 max_seq_length=2048
Running EleutherEvalRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original
  checkpoint_files:
  - consolidated.pth
  model_type: LLAMA3_VISION
  output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
limit: 5
log_level: INFO
max_seq_length: 2048
model:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
quantizer: null
seed: 1234
tasks:
- mmmu_val_science
tokenizer:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  max_seq_len: 8192
  path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model

2024-10-09:15:51:36,147 INFO     [_logging.py:101] Running EleutherEvalRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original
  checkpoint_files:
  - consolidated.pth
  model_type: LLAMA3_VISION
  output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
limit: 5
log_level: INFO
max_seq_length: 2048
model:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
quantizer: null
seed: 1234
tasks:
- mmmu_val_science
tokenizer:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  max_seq_len: 8192
  path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model

Model is initialized with precision torch.bfloat16.
2024-10-09:15:51:39,457 INFO     [eleuther_eval.py:500] Model is initialized with precision torch.bfloat16.
Running evaluation on the following tasks: ['mmmu_val_science']
2024-10-09:15:51:51,298 INFO     [eleuther_eval.py:549] Running evaluation on the following tasks: ['mmmu_val_science']
2024-10-09:15:51:51,302 INFO     [task.py:415] Building contexts for mmmu_val_biology on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13001.56it/s]
2024-10-09:15:51:51,338 INFO     [task.py:415] Building contexts for mmmu_val_chemistry on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 20846.44it/s]
2024-10-09:15:51:51,342 INFO     [task.py:415] Building contexts for mmmu_val_geography on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 20846.44it/s]
2024-10-09:15:51:51,359 INFO     [task.py:415] Building contexts for mmmu_val_math on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 14246.96it/s]
2024-10-09:15:51:51,377 INFO     [task.py:415] Building contexts for mmmu_val_physics on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 21822.60it/s]
2024-10-09:15:51:51,392 INFO     [evaluator.py:489] Running generate_until requests
Running generate_until requests with text+image input: 100%|███████████████████████████████████████████| 25/25 [04:34<00:00, 10.98s/it]
Eval completed in 274.73 seconds.
2024-10-09:15:56:26,028 INFO     [eleuther_eval.py:558] Eval completed in 274.73 seconds.
Max memory allocated: 32.86 GB
2024-10-09:15:56:26,028 INFO     [eleuther_eval.py:559] Max memory allocated: 32.86 GB


|   Tasks    |Version|Filter|n-shot|Metric|   |Value|   |Stderr|
|------------|------:|------|------|------|---|----:|---|-----:|
|Science     |      0|none  |      |acc   || 0.32|±  |0.0938|
| - Biology  |      0|none  |None  |acc   || 0.20|±  |0.2000|
| - Chemistry|      0|none  |None  |acc   || 0.00|±  |0.0000|
| - Geography|      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Math     |      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Physics  |      0|none  |None  |acc   || 0.60|±  |0.2449|


2024-10-09:15:56:26,086 INFO     [eleuther_eval.py:563] 

|   Tasks    |Version|Filter|n-shot|Metric|   |Value|   |Stderr|
|------------|------:|------|------|------|---|----:|---|-----:|
|Science     |      0|none  |      |acc   || 0.32|±  |0.0938|
| - Biology  |      0|none  |None  |acc   || 0.20|±  |0.2000|
| - Chemistry|      0|none  |None  |acc   || 0.00|±  |0.0000|
| - Geography|      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Math     |      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Physics  |      0|none  |None  |acc   || 0.60|±  |0.2449|


On this branch

root@736fb59b1bb9:~/torchtune#  tune run eleuther_eval --config llama3_2_vision/evaluation limit=5 max_seq_length=2048
Running EleutherEvalRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original
  checkpoint_files:
  - consolidated.pth
  model_type: LLAMA3_VISION
  output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
limit: 5
log_level: INFO
max_seq_length: 2048
model:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
quantizer: null
seed: 1234
tasks:
- mmmu_val_science
tokenizer:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  max_seq_len: 8192
  path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model

2024-10-09:15:45:53,771 INFO     [_logging.py:101] Running EleutherEvalRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original
  checkpoint_files:
  - consolidated.pth
  model_type: LLAMA3_VISION
  output_dir: ./
device: cuda
dtype: bf16
enable_kv_cache: true
limit: 5
log_level: INFO
max_seq_length: 2048
model:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
quantizer: null
seed: 1234
tasks:
- mmmu_val_science
tokenizer:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  max_seq_len: 8192
  path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model

Model is initialized with precision torch.bfloat16.
2024-10-09:15:45:57,532 INFO     [eleuther_eval.py:500] Model is initialized with precision torch.bfloat16.
Running evaluation on the following tasks: ['mmmu_val_science']
2024-10-09:15:46:10,590 INFO     [eleuther_eval.py:549] Running evaluation on the following tasks: ['mmmu_val_science']
2024-10-09:15:46:10,594 INFO     [task.py:415] Building contexts for mmmu_val_biology on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 12336.19it/s]
2024-10-09:15:46:10,631 INFO     [task.py:415] Building contexts for mmmu_val_chemistry on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 20301.57it/s]
2024-10-09:15:46:10,635 INFO     [task.py:415] Building contexts for mmmu_val_geography on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 15911.62it/s]
2024-10-09:15:46:10,653 INFO     [task.py:415] Building contexts for mmmu_val_math on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13626.72it/s]
2024-10-09:15:46:10,671 INFO     [task.py:415] Building contexts for mmmu_val_physics on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 20223.26it/s]
2024-10-09:15:46:10,687 INFO     [evaluator.py:489] Running generate_until requests
Running generate_until requests with text+image input: 100%|███████████████████████████████████████████| 25/25 [04:33<00:00, 10.92s/it]
Eval completed in 273.26 seconds.
2024-10-09:15:50:43,852 INFO     [eleuther_eval.py:558] Eval completed in 273.26 seconds.
Max memory allocated: 32.86 GB
2024-10-09:15:50:43,852 INFO     [eleuther_eval.py:559] Max memory allocated: 32.86 GB


|   Tasks    |Version|Filter|n-shot|Metric|   |Value|   |Stderr|
|------------|------:|------|------|------|---|----:|---|-----:|
|Science     |      0|none  |      |acc   || 0.32|±  |0.0938|
| - Biology  |      0|none  |None  |acc   || 0.20|±  |0.2000|
| - Chemistry|      0|none  |None  |acc   || 0.00|±  |0.0000|
| - Geography|      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Math     |      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Physics  |      0|none  |None  |acc   || 0.60|±  |0.2449|


2024-10-09:15:50:43,909 INFO     [eleuther_eval.py:563] 

|   Tasks    |Version|Filter|n-shot|Metric|   |Value|   |Stderr|
|------------|------:|------|------|------|---|----:|---|-----:|
|Science     |      0|none  |      |acc   || 0.32|±  |0.0938|
| - Biology  |      0|none  |None  |acc   || 0.20|±  |0.2000|
| - Chemistry|      0|none  |None  |acc   || 0.00|±  |0.0000|
| - Geography|      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Math     |      0|none  |None  |acc   || 0.40|±  |0.2449|
| - Physics  |      0|none  |None  |acc   || 0.60|±  |0.2449|

Text eval results

(tune) salman@combuter:~/torchtune$ tune run eleuther_eval --config target/eleuther_evaluation.yaml 
2024-10-08:21:18:07,202 INFO     [_logging.py:101] Running EleutherEvalRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ./target/1b_normal
  checkpoint_files:
  - pytorch_model.bin
  model_type: LLAMA2
  output_dir: ./target/tmp
device: cuda
dtype: bf16
enable_kv_cache: true
limit: 20
max_seq_length: 1024
model:
  _component_: torchtune.models.llama2.llama2
  embed_dim: 2048
  max_seq_len: 4096
  norm_eps: 1.0e-05
  num_heads: 32
  num_kv_heads: 4
  num_layers: 22
  vocab_size: 32000
quantizer: null
seed: 1234
tasks:
- truthfulqa_gen
- truthfulqa_mc2
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: ./target/1b_normal/tokenizer.model

2024-10-08:21:18:08,854 INFO     [eleuther_eval.py:495] Model is initialized with precision torch.bfloat16.
2024-10-08:21:18:08,879 INFO     [huggingface.py:132] Using device 'cuda:0'
/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
2024-10-08:21:18:09,228 INFO     [huggingface.py:368] Model parallel was set to False, max memory was not set, and device map was set to {'': 'cuda:0'}
2024-10-08:21:18:10,532 INFO     [__init__.py:491] `group` and `group_alias` keys in TaskConfigs are deprecated and will be removed in v0.4.5 of lm_eval. The new `tag` field will be used to allow for a shortcut to a group of tasks one does not wish to aggregate metrics across. `group`s which aggregate across subtasks must be only defined in a separate group config file, which will be the official way to create groups that support cross-task aggregation as in `mmlu`. Please see the v0.4.4 patch notes and our documentation: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md#advanced-group-configs for more information.
2024-10-08:21:18:23,103 INFO     [eleuther_eval.py:537] Running evaluation on the following tasks: ['truthfulqa_gen', 'truthfulqa_mc2']
2024-10-08:21:18:23,106 INFO     [task.py:428] Building contexts for truthfulqa_mc2 on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 903.93it/s]
2024-10-08:21:18:23,130 INFO     [task.py:428] Building contexts for truthfulqa_gen on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 1604.19it/s]
2024-10-08:21:18:23,147 INFO     [evaluator.py:485] Running loglikelihood requests
Running loglikelihood requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:18<00:00,  8.37it/s]
2024-10-08:21:18:41,510 INFO     [evaluator.py:485] Running generate_until requests
Running generate_until requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [02:48<00:00,  8.43s/it]
2024-10-08:21:21:30,179 INFO     [rouge_scorer.py:83] Using default tokenizer.
2024-10-08:21:21:49,314 INFO     [eleuther_eval.py:546] Eval completed in 206.21 seconds.
2024-10-08:21:21:49,314 INFO     [eleuther_eval.py:547] Max memory allocated: 3.41 GB
2024-10-08:21:21:49,432 INFO     [eleuther_eval.py:551] 

|    Tasks     |Version|Filter|n-shot|  Metric   |   | Value  |   |Stderr|
|--------------|------:|------|-----:|-----------|---|-------:|---|-----:|
|truthfulqa_gen|      3|none  |     0|bleu_acc   ||  0.3000|±  |0.1051|
|              |       |none  |     0|bleu_diff  ||-11.1260|±  |3.5818|
|              |       |none  |     0|bleu_max   || 23.0317|±  |4.5452|
|              |       |none  |     0|rouge1_acc ||  0.4500|±  |0.1141|
|              |       |none  |     0|rouge1_diff|| -6.7262|±  |3.7765|
|              |       |none  |     0|rouge1_max || 49.9555|±  |4.9504|
|              |       |none  |     0|rouge2_acc ||  0.3000|±  |0.1051|
|              |       |none  |     0|rouge2_diff||-12.0537|±  |4.2492|
|              |       |none  |     0|rouge2_max || 31.7591|±  |6.2387|
|              |       |none  |     0|rougeL_acc ||  0.3500|±  |0.1094|
|              |       |none  |     0|rougeL_diff|| -8.0059|±  |3.4676|
|              |       |none  |     0|rougeL_max || 47.6440|±  |5.3441|
|truthfulqa_mc2|      2|none  |     0|acc        ||  0.4769|±  |0.0947|

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot (bot) commented Oct 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1763

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7906807 with merge base 3ca0d30:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 7, 2024
@codecov-commenter commented Oct 8, 2024

Codecov Report

Attention: Patch coverage is 18.98734% with 64 lines in your changes missing coverage. Please review.

Project coverage is 25.70%. Comparing base (7cf656b) to head (480c2b3).
Report is 7 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|--------------------------|--------:|-------|
| torchtune/modules/common_utils.py | 20.58% | 27 Missing ⚠️ |
| recipes/eleuther_eval.py | 0.00% | 19 Missing ⚠️ |
| torchtune/modules/transformer.py | 41.66% | 7 Missing ⚠️ |
| tests/torchtune/modules/test_attention.py | 0.00% | 3 Missing ⚠️ |
| torchtune/modules/attention.py | 0.00% | 3 Missing ⚠️ |
| ...rchtune/modules/model_fusion/test_fusion_models.py | 33.33% | 2 Missing ⚠️ |
| torchtune/generation/_generation.py | 0.00% | 1 Missing ⚠️ |
| torchtune/models/gemma/transformer.py | 0.00% | 1 Missing ⚠️ |
| torchtune/modules/model_fusion/_fusion.py | 50.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1763       +/-   ##
===========================================
- Coverage   69.33%   25.70%   -43.63%     
===========================================
  Files         305      305               
  Lines       15892    16003      +111     
===========================================
- Hits        11018     4113     -6905     
- Misses       4874    11890     +7016     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@SalmanMohammadi SalmanMohammadi marked this pull request as ready for review October 8, 2024 20:43
@SalmanMohammadi (Collaborator, Author) commented:

I've actually just realized this isn't going to play nice with fusion models and we may have to pollute TransformerDecoder a bit here. Bear with me.

Review comment from a Contributor on torchtune/modules/common_utils.py:

    @contextlib.contextmanager
    def setup_use_local_kv_cache(

nit: I don't like this name (I also don't have a great suggestion that isn't extremely long). I do think we should change local -> temporary, and I would like to say "remove setup from the name". I'm on the fence about that not being explicit enough, but I kinda think temporary implies that we are constructing and destroying as part of the context manager. What do you think?

A Contributor commented on this diff in the eval recipe tests (the hunk removes the old test asserting that mixing generate_until and multiple-choice tasks raises an error):

    @@ -199,12 +199,3 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
    match="QAT quantizers should only be used during quantization aware training",
    ):
    runpy.run_path(TUNE_PATH, run_name="__main__")

    @pytest.mark.integration_test
    def test_eval_recipe_errors_with_generate_until_and_mc_tasks(

Should we instead have a test that this now actually works?

@SalmanMohammadi (Collaborator, Author) commented:

> Yeah I think this touches on my main point of confusion with this PR: the distinction between caches_are_enabled and caches_are_setup. I thought I had an understanding of this but now I'm not so sure. It seems to me like some of the places we now call caches_are_setup (e.g. checking whether mask or input pos is passed in forward) should actually be calling caches_are_enabled.

I think I see where you're coming from here. I think what would help disambiguate is having model.caches_are_enabled(), but I tried my hardest not to add additional functions to TransformerDecoder and DeepFusionModel. In retrospect this was a silly goal and things would be significantly clearer if we had that method.

I think going forward, and to significantly simplify things, I'd propose:

  1. The existing API around setup_caches is unchanged. You may call setup_caches, then use KV-caching if you wish.
  2. We add a with temporary_kv_cache or with local_kv_cache or with kv_cache which will enable KV-caching on a model that does not already have KV-caching enabled, and completely tear the caches down upon exit (see the sketch below).
  3. We add a with disable_kv_cache which disables KV-caching on a model that already has KV-caching enabled.
  4. We add a model.caches_are_enabled()

This means:

  1. We aren't making any breaking changes.
  2. We're cementing the distinction between caches_are_enabled and caches_are_setup, such that most of the checks we make inside TransformerDecoder for input validation will use caches_are_enabled.
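
To make the proposal concrete, here is a rough usage sketch. It assumes the context managers land in torchtune.modules.common_utils (the file this PR touches) under the names local_kv_cache and disable_kv_cache, and it reuses the llama2 builder arguments from the text-eval config above. Names and exact signatures were still being discussed at this point, so treat it as illustrative rather than the final API.

    import torch

    from torchtune.models.llama2 import llama2
    # Hypothetical import: the context managers live in
    # torchtune/modules/common_utils.py in this PR, but their final names and
    # signatures were still under discussion here.
    from torchtune.modules.common_utils import disable_kv_cache, local_kv_cache

    # Same ~1B Llama2 config as in the text eval above.
    model = llama2(
        vocab_size=32_000,
        num_layers=22,
        num_heads=32,
        num_kv_heads=4,
        embed_dim=2048,
        max_seq_len=4096,
    )
    device = next(model.parameters()).device

    # (2) Temporarily enable KV-caching on a model with no caches set up:
    # caches are created on entry and completely torn down on exit.
    with local_kv_cache(model, batch_size=1, device=device, dtype=torch.bfloat16):
        assert model.caches_are_enabled()
        # ... run generation here ...
    assert not model.caches_are_enabled()

    # (1) The existing API is unchanged: set up caches yourself and keep them.
    model.setup_caches(batch_size=1, dtype=torch.bfloat16)

    # (3) Temporarily disable caching on a model that already has caches: the
    # cache contents are left untouched and control is handed back on exit.
    with disable_kv_cache(model):
        assert not model.caches_are_enabled()
        # ... ordinary forward passes, nothing read from or written to the caches ...
    assert model.caches_are_enabled()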

@ebsmothers (Contributor) commented:

> Yeah I think this touches on my main point of confusion with this PR: the distinction between caches_are_enabled and caches_are_setup. I thought I had an understanding of this but now I'm not so sure. It seems to me like some of the places we now call caches_are_setup (e.g. checking whether mask or input pos is passed in forward) should actually be calling caches_are_enabled.
>
> I think I see where you're coming from here. I think what would help disambiguate is having model.caches_are_enabled(), but I tried my hardest not to add additional functions to TransformerDecoder and DeepFusionModel. In retrospect this was a silly goal and things would be significantly clearer if we had that method.
>
> I think going forward, and to significantly simplify things, I'd propose:
>
>   1. The existing API around setup_caches is unchanged. You may call setup_caches, then use KV-caching if you wish.
>   2. We add a with temporary_kv_cache or with local_kv_cache or with kv_cache which will enable KV-caching on a model that does not already have KV-caching enabled, and completely tear the caches down upon exit.
>   3. We add a with disable_kv_cache which disables KV-caching on a model that already has KV-caching enabled.
>   4. We add a model.caches_are_enabled()
>
> This means:
>
>   1. We aren't making any breaking changes.
>   2. We're cementing the distinction between caches_are_enabled and caches_are_setup, such that most of the checks we make inside TransformerDecoder for input validation will use caches_are_enabled.

Overall this sounds pretty reasonable to me. My main question is around (3). If we just locally disable KV cache, when do we reset (if at all)? Like if I set up the cache, call forward a bunch of times (with it now enabled by default), wrap a few forward calls in with disable_kv_cache, what will happen on my next forward outside the disable_kv_cache context manager? Am I starting from scratch or resuming where I left off?

And then does this also mean that we now have both caches_are_enabled and caches_are_setup properties in TransformerDecoder and DeepFusionModel? Just to be explicit, can we clarify which classes we are defining each of these methods on (and why) across the board? I think that may help add some method to the madness here.

@SalmanMohammadi (Collaborator, Author) commented Oct 10, 2024

> My main question is around (3). If we just locally disable KV cache, when do we reset (if at all)? Like if I set up the cache, call forward a bunch of times (with it now enabled by default), wrap a few forward calls in with disable_kv_cache, what will happen on my next forward outside the disable_kv_cache context manager? Am I starting from scratch or resuming where I left off?

The onus is on the user here to correctly call .reset_caches whenever they're finished with a particular generation. I think we should be clear that the behaviour of with disable_kv_cache is that your KV-caches are left untouched and control of them is returned to you upon exit. As it stands, it pretty much only exists for the PPO recipe.

> And then does this also mean that we now have both caches_are_enabled and caches_are_setup properties in TransformerDecoder and DeepFusionModel? Just to be explicit, can we clarify which classes we are defining each of these methods on (and why) across the board?

Sorry, yes.

  • We already have caches_are_enabled on both TransformerDecoder and DeepFusion. I'd like to add caches_are_setup, which will validate operations on the KV-caches that we can always perform (regardless of their enabled-ness): model.reset_caches, delete_caches(model), model.setup_caches.
  • However, it seems sensible to also have a check for operations which rely on the enabled-ness of the caches, i.e. model.forward will use self.caches_are_enabled to check that you're correctly passing inference-relevant args.

EDIT:
Okay total miss from me, sorry. We already have caches_are_enabled, but in this PR I'd like to add caches_are_setup. Fixed above.
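
A minimal sketch of the resulting distinction, with the same caveat that naming was still in flux at this point: caches_are_setup answers "do KV-caches exist on this model?", while caches_are_enabled answers "will forward passes actually use them right now?".

    import torch
    from torchtune.models.llama2 import llama2

    # Illustrative only -- assumes both checks end up on TransformerDecoder as
    # discussed in this thread.
    model = llama2(
        vocab_size=32_000, num_layers=22, num_heads=32, num_kv_heads=4,
        embed_dim=2048, max_seq_len=4096,
    )

    assert not model.caches_are_setup()    # no KV-caches have been created yet
    assert not model.caches_are_enabled()  # so they cannot be in use either

    model.setup_caches(batch_size=1, dtype=torch.bfloat16)
    assert model.caches_are_setup()        # caches now exist...
    assert model.caches_are_enabled()      # ...and forward() will expect inference args

    # Input validation in forward() keys off caches_are_enabled(); cache management
    # (setup_caches, reset_caches, deleting caches) only needs caches_are_setup().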

@SalmanMohammadi (Collaborator, Author) commented Oct 11, 2024

As another side note here, we currently expect attention layers TransformerSelfAttentionLayer and TransformerCrossAttentionLayer to implement self.setup_cache, self.cache_is_setup, self.reset_cache, but all these functions do is call the same method on self.attn.

This is the same across both kinds of attention layers.

We're already showing here that we don't need this by doing things like

    for module in model.modules():
        if hasattr(module, "kv_cache") and callable(module.kv_cache):
            module.cache_enabled = False
            module.kv_cache = None

so why don't we get rid of the intermediary functions and use the API on MultiHeadAttention directly?
@ebsmothers @joecummings @pbontrager

edit: sorry, I know I'm terrible for over-scoping PRs but I can't help it.
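
For illustration, "using the API on MultiHeadAttention directly" could look like the sketch below. MultiHeadAttention, its kv_cache attribute, and reset_cache come from the snippet and discussion above; the helper function itself is hypothetical and not part of this PR.

    from torchtune.modules import MultiHeadAttention, TransformerDecoder


    def reset_all_kv_caches(model: TransformerDecoder) -> None:
        """Hypothetical helper: reset every KV-cache by calling the attention
        modules directly, skipping the per-layer pass-through methods."""
        for module in model.modules():
            if isinstance(module, MultiHeadAttention) and module.kv_cache is not None:
                module.reset_cache()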

@joecummings (Contributor) commented:

> As another side note here, we currently expect attention layers TransformerSelfAttentionLayer and TransformerCrossAttentionLayer to implement self.setup_cache, self.cache_is_setup, self.reset_cache, but all these functions do is call the same method on self.attn.
>
> This is the same across both kinds of attention layers.
>
> We're already showing here that we don't need this by doing things like
>
>     for module in model.modules():
>         if hasattr(module, "kv_cache") and callable(module.kv_cache):
>             module.cache_enabled = False
>             module.kv_cache = None
>
> so why don't we get rid of the intermediary functions and use the API on MultiHeadAttention directly? @ebsmothers @joecummings @pbontrager
>
> edit: sorry, I know I'm terrible for over-scoping PRs but I can't help it.

You're right, I think we should be doing this but BAD SALMAN STOP OVERSCOPING PRs. Let's open a separate issue for follow-up?

@joecummings mentioned this pull request Oct 15, 2024
@ebsmothers (Contributor) commented:

> As another side note here, we currently expect attention layers TransformerSelfAttentionLayer and TransformerCrossAttentionLayer to implement self.setup_cache, self.cache_is_setup, self.reset_cache, but all these functions do is call the same method on self.attn.
> This is the same across both kinds of attention layers.
> We're already showing here that we don't need this by doing things like
>
>     for module in model.modules():
>         if hasattr(module, "kv_cache") and callable(module.kv_cache):
>             module.cache_enabled = False
>             module.kv_cache = None
>
> so why don't we get rid of the intermediary functions and use the API on MultiHeadAttention directly? @ebsmothers @joecummings @pbontrager
> edit: sorry, I know I'm terrible for over-scoping PRs but I can't help it.
>
> You're right, I think we should be doing this but BAD SALMAN STOP OVERSCOPING PRs. Let's open a separate issue for follow-up?

Yeah agreed, let's tackle this separately. Sorry I slept on this PR for a few days, what's the status here? Like are we ready for another pass or are there still open questions we need to sort out?

@SalmanMohammadi (Collaborator, Author) commented:

Yeah, this is good to review again. I'll fix merge conflicts and add in the eval test you requested later this week.

@joecummings (Contributor) left a review:

Couple nits, but I'm ready to sign off on this PR.

@SalmanMohammadi SalmanMohammadi merged commit 73aa126 into pytorch:main Oct 20, 2024
17 checks passed
@SalmanMohammadi SalmanMohammadi deleted the toggle_kv_cache branch October 21, 2024 14:07
Successfully merging this pull request may close these issues:

  • Fix eval recipe for consecutive generation and non-generation tasks