tidying it allll up
SalmanMohammadi committed Oct 20, 2024
1 parent f134e4c · commit 84e8cc5
Showing 3 changed files with 59 additions and 41 deletions.
tests/torchtune/modules/model_fusion/test_fusion_layer.py (2 changes: 1 addition & 1 deletion)
@@ -137,7 +137,7 @@ def test_fusion_params(self, fused_layer):
"fusion_layer.linear.bias",
}

- def test_setup_cache(self, fused_layer):
+ def test_setup_caches(self, fused_layer):
"""
Test that the cache methods works as expected.
"""
torchtune/modules/model_fusion/_fusion.py (13 changes: 8 additions & 5 deletions)
@@ -122,13 +122,16 @@ def setup_caches(
)

def caches_are_setup(self) -> bool:
"""Check if the key value caches have been setup."""
"""
Check if the key value caches are setup on ``self.layer``.
See :func:~torchtune.modules.TransformerDecoder.caches_are_setup`.
"""
return self.layer.caches_are_setup()

def caches_are_enabled(self) -> bool:
"""
- Checks if the key value caches are enabled. KV-caches must also have been setup
- for them to be enabled.
+ Checks if the key value caches on ``self.layer`` are enabled.
+ See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
"""
return self.layer.caches_are_enabled()

@@ -392,8 +395,8 @@ def setup_caches(

def caches_are_setup(self) -> bool:
"""
- Check if the key value caches are setup. This means `setup_caches` has been called, and
- the relevant attention modules in the model have created `KVCache`s.
+ Check if the key value caches are setup. This means ``setup_caches`` has been called, and
+ the relevant attention modules in the model have created their ``KVCache``.
"""
return self.decoder.caches_are_setup()

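The updated ``FusionLayer`` docstrings above defer to ``TransformerDecoder`` for what "setup" and "enabled" mean. A rough sketch of that distinction, and of ``torchtune.modules.disable_kv_cache`` (which the decoder-level docstring further down this diff refers to), assuming a small model built with the ``torchtune.models.llama2.llama2`` builder; the sizes below are illustrative and not part of this commit:

import torch
from torchtune.models.llama2 import llama2
from torchtune.modules import disable_kv_cache

# Tiny illustrative decoder; the builder arguments are assumptions, not from this diff.
model = llama2(
    vocab_size=32_000, num_layers=2, num_heads=4,
    num_kv_heads=4, embed_dim=128, max_seq_len=256,
)

# Before setup_caches, nothing is allocated: both checks are False.
assert not model.caches_are_setup()
assert not model.caches_are_enabled()

# setup_caches allocates a KVCache on every attention module, so both checks become True.
model.setup_caches(batch_size=1, dtype=torch.float32)
assert model.caches_are_setup() and model.caches_are_enabled()

# disable_kv_cache temporarily stops forward passes from updating the caches without
# deleting the cache tensors: "setup" stays True while "enabled" flips to False.
with disable_kv_cache(model):
    assert model.caches_are_setup()
    assert not model.caches_are_enabled()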
torchtune/modules/transformer.py (85 changes: 50 additions & 35 deletions)
@@ -63,14 +63,20 @@ def setup_caches(
"""
self.attn.setup_cache(batch_size, dtype, max_seq_len=decoder_max_seq_len)

- def caches_are_enabled(self) -> bool:
- """Check if key value caches are enabled."""
- return self.attn.cache_enabled
-
def caches_are_setup(self) -> bool:
"""Check if the key value caches are setup."""
"""
Check if the key value caches are setup on ``self.attn``.
See :func:~torchtune.modules.TransformerDecoder.caches_are_setup`.
"""
return self.attn.kv_cache is not None

+ def caches_are_enabled(self) -> bool:
+ """
+ Checks if the key value caches on ``self.attn`` are enabled.
+ See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
+ """
+ return self.attn.cache_enabled
+
def reset_cache(self):
"""Reset the key value caches."""
self.attn.reset_cache()
@@ -188,25 +194,20 @@ def setup_caches(

def caches_are_setup(self) -> bool:
"""
- Check if the key value caches are setup. This means `setup_caches` has been called, and
- the relevant attention modules in the model have created `KVCache`s.
+ Check if the key value caches are setup on ``self.attn``.
+ See :func:`~torchtune.modules.TransformerDecoder.caches_are_setup`.
"""
return self.attn.kv_cache is not None

def caches_are_enabled(self) -> bool:
"""
- Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
- attention modules will be "enabled" and all forward passes will update the caches. This behaviour
- can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
- using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
+ Checks if the key value caches on ``self.attn`` are enabled.
+ See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
"""
return self.attn.cache_enabled

def reset_cache(self):
"""
Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero,
without deleting or reallocating cache tensors.
"""
"""Reset the key value caches."""
self.attn.reset_cache()

def _skip_mask(self, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
@@ -270,7 +271,7 @@ def forward(
"""
# During decoding, it's possible encoder_input is None because the embeds
# are already stored in the kv cache.
- empty_cache = not self.caches_are_setup() or self.attn.kv_cache.size == 0
+ empty_cache = not self.caches_are_enabled() or self.attn.kv_cache.size == 0
# Skip cross attention when no secondary input as it's primary purpose
# is to attend between x and encoder_input.
if encoder_input is None and empty_cache:
@@ -448,26 +449,34 @@ def setup_caches(
)

def caches_are_setup(self) -> bool:
"""Check if the key value caches have been setup."""
"""
Check if the key value caches are setup. This means ``setup_caches`` has been called, and
the relevant attention modules in the model have created their ``KVCache``.
"""
return self.layers[0].caches_are_setup()

+ def caches_are_enabled(self) -> bool:
+ """
+ Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
+ attention modules will be "enabled" and all forward passes will update the caches. This behaviour
+ can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
+ using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
+ """
+ return self.layers[0].caches_are_enabled()
+
def reset_caches(self):
"""Reset the key value caches."""
if not self.caches_are_setup():
"""
Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero,
without deleting or reallocating cache tensors.
"""
if not self.caches_are_enabled():
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
)

for layer in self.layers:
layer.reset_cache()

- def caches_are_enabled(self) -> bool:
- """
- Checks if the key value caches are enabled. KV-caches must also have been setup
- for them to be enabled.
- """
- return self.layers[0].caches_are_enabled() and self.caches_are_setup()

@torch.compiler.disable
def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
"""
@@ -605,7 +614,13 @@ def forward(
# input tensor of shape [b, s]
seq_len = tokens.shape[1]

- self._validate_inputs(seq_len, mask, encoder_input, encoder_mask, input_pos)
+ self._validate_inputs(
+ seq_len,
+ mask=mask,
+ encoder_input=encoder_input,
+ encoder_mask=encoder_mask,
+ input_pos=input_pos,
+ )

# shape: [b, s, d]
h = self.tok_embeddings(tokens)
@@ -777,28 +792,28 @@ def setup_caches(
self.decoder_max_cache_seq_len = self.decoder_max_cache_seq_len

for layer in self.layers:
- layer.setup_cache(
+ layer.setup_caches(
batch_size,
dtype,
self.encoder_max_cache_seq_len,
self.decoder_max_cache_seq_len,
)

@property
- def encoder_caches_are_setup(self) -> bool:
+ def encoder_caches_are_enabled(self) -> bool:
"""Checks if there are any :class:`~torchtune.modules.TransformerCrossAttentionLayer`,
or :class:`~torchtune.modules.fusion.FusionLayer` layers which have cache enabled.
"""
return self.encoder_max_cache_seq_len is not None

@property
- def decoder_caches_are_setup(self) -> bool:
+ def decoder_caches_are_enabled(self) -> bool:
"""Check if the key value caches are setup."""
return self.decoder_max_cache_seq_len is not None

def reset_caches(self):
"""Reset the key value caches."""
- if not (self.encoder_caches_are_setup or self.decoder_caches_are_setup):
+ if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled):
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
)
@@ -873,22 +888,22 @@ def forward(
# shape: [b, s, d]
h = self.tok_embeddings(tokens)

- if self.decoder_caches_are_setup:
+ if self.decoder_caches_are_enabled:
if mask is None:
raise ValueError(
"KV-caches for self-attention layers are setup for inference mode, masks must be provided!"
" Use the `mask` arg to provide a mask."
)
- if self.encoder_caches_are_setup:
+ if self.encoder_caches_are_enabled:
if encoder_mask is None:
raise ValueError(
"KV-caches for cross-attention/fusion layers are setup for inference mode, encoder masks must be provided!"
" Use the `encoder_mask` arg to provide an encoder mask."
)

if (
- self.encoder_caches_are_setup
- or self.decoder_caches_are_setup
+ self.encoder_caches_are_enabled
+ or self.decoder_caches_are_enabled
and input_pos is None
):
raise ValueError(
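Taken together, the renames above mean the fusion model gates its mask/``input_pos`` validation and ``reset_caches`` on whether caches are enabled rather than merely set up. A hedged end-to-end sketch of the intended usage on a plain ``TransformerDecoder``, reusing the illustrative llama2 builder arguments from the earlier snippet; the ``[batch, seq_len, cache_len]`` boolean mask shape is an assumption about the cached-attention convention, not something this commit specifies:

import torch
from torchtune.models.llama2 import llama2

model = llama2(
    vocab_size=32_000, num_layers=2, num_heads=4,
    num_kv_heads=4, embed_dim=128, max_seq_len=256,
)

# reset_caches raises if caches were never set up (or are currently disabled).
try:
    model.reset_caches()
except RuntimeError as err:
    print(err)

model.setup_caches(batch_size=1, dtype=torch.float32)

# With caches enabled, the forward pass expects input_pos and an explicit attention
# mask, mirroring the checks this diff applies in the fusion model's forward.
tokens = torch.randint(0, 32_000, (1, 8))
input_pos = torch.arange(8).unsqueeze(0)                              # [b, s]
mask = torch.tril(torch.ones(8, 256, dtype=torch.bool)).unsqueeze(0)  # [b, s, cache_len], assumed shape
logits = model(tokens, mask=mask, input_pos=input_pos)

# Zero the cache buffers in place (no reallocation) before starting a new prompt.
model.reset_caches()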
