From 84e8cc5d94761145b2c6414ddfb869ba5beb81be Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Sun, 20 Oct 2024 23:15:07 +0100
Subject: [PATCH] tidying it allll up

---
 .../modules/model_fusion/test_fusion_layer.py |  2 +-
 torchtune/modules/model_fusion/_fusion.py     | 13 +-
 torchtune/modules/transformer.py              | 85 +++++++++++--------
 3 files changed, 59 insertions(+), 41 deletions(-)

diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layer.py
index 1822c16e5d..a2fc0715eb 100644
--- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py
+++ b/tests/torchtune/modules/model_fusion/test_fusion_layer.py
@@ -137,7 +137,7 @@ def test_fusion_params(self, fused_layer):
             "fusion_layer.linear.bias",
         }
 
-    def test_setup_cache(self, fused_layer):
+    def test_setup_caches(self, fused_layer):
         """
         Test that the cache methods works as expected.
         """
diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py
index cf488d3a02..40ede4feec 100644
--- a/torchtune/modules/model_fusion/_fusion.py
+++ b/torchtune/modules/model_fusion/_fusion.py
@@ -122,13 +122,16 @@ def setup_caches(
         )
 
     def caches_are_setup(self) -> bool:
-        """Check if the key value caches have been setup."""
+        """
+        Check if the key value caches are setup on ``self.layer``.
+        See :func:`~torchtune.modules.TransformerDecoder.caches_are_setup`.
+        """
         return self.layer.caches_are_setup()
 
     def caches_are_enabled(self) -> bool:
         """
-        Checks if the key value caches are enabled. KV-caches must also have been setup
-        for them to be enabled.
+        Checks if the key value caches on ``self.layer`` are enabled.
+        See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
         """
         return self.layer.caches_are_enabled()
 
@@ -392,8 +395,8 @@ def setup_caches(
 
     def caches_are_setup(self) -> bool:
         """
-        Check if the key value caches are setup. This means `setup_caches` has been called, and
-        the relevant attention modules in the model have created `KVCache`s.
+        Check if the key value caches are setup. This means ``setup_caches`` has been called, and
+        the relevant attention modules in the model have created their ``KVCache``.
         """
         return self.decoder.caches_are_setup()
 
diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py
index 3b2d356c29..910cb8273b 100644
--- a/torchtune/modules/transformer.py
+++ b/torchtune/modules/transformer.py
@@ -63,14 +63,20 @@ def setup_caches(
         """
         self.attn.setup_cache(batch_size, dtype, max_seq_len=decoder_max_seq_len)
 
-    def caches_are_enabled(self) -> bool:
-        """Check if key value caches are enabled."""
-        return self.attn.cache_enabled
-
     def caches_are_setup(self) -> bool:
-        """Check if the key value caches are setup."""
+        """
+        Check if the key value caches are setup on ``self.attn``.
+        See :func:`~torchtune.modules.TransformerDecoder.caches_are_setup`.
+        """
         return self.attn.kv_cache is not None
 
+    def caches_are_enabled(self) -> bool:
+        """
+        Checks if the key value caches on ``self.attn`` are enabled.
+        See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
+        """
+        return self.attn.cache_enabled
+
     def reset_cache(self):
         """Reset the key value caches."""
         self.attn.reset_cache()
@@ -188,25 +194,20 @@ def setup_caches(
 
     def caches_are_setup(self) -> bool:
         """
-        Check if the key value caches are setup. This means `setup_caches` has been called, and
-        the relevant attention modules in the model have created `KVCache`s.
+        Check if the key value caches are setup on ``self.attn``.
+        See :func:`~torchtune.modules.TransformerDecoder.caches_are_setup`.
         """
         return self.attn.kv_cache is not None
 
     def caches_are_enabled(self) -> bool:
         """
-        Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
-        attention modules will be "enabled" and all forward passes will update the caches. This behaviour
-        can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
-        using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
+        Checks if the key value caches on ``self.attn`` are enabled.
+        See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
         """
         return self.attn.cache_enabled
 
     def reset_cache(self):
-        """
-        Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero,
-        without deleting or reallocating cache tensors.
-        """
+        """Reset the key value caches."""
         self.attn.reset_cache()
 
     def _skip_mask(self, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
@@ -270,7 +271,7 @@ def forward(
         """
         # During decoding, it's possible encoder_input is None because the embeds
         # are already stored in the kv cache.
-        empty_cache = not self.caches_are_setup() or self.attn.kv_cache.size == 0
+        empty_cache = not self.caches_are_enabled() or self.attn.kv_cache.size == 0
         # Skip cross attention when no secondary input as it's primary purpose
         # is to attend between x and encoder_input.
         if encoder_input is None and empty_cache:
@@ -448,12 +449,27 @@ def setup_caches(
         )
 
     def caches_are_setup(self) -> bool:
-        """Check if the key value caches have been setup."""
+        """
+        Check if the key value caches are setup. This means ``setup_caches`` has been called, and
+        the relevant attention modules in the model have created their ``KVCache``.
+        """
         return self.layers[0].caches_are_setup()
 
+    def caches_are_enabled(self) -> bool:
+        """
+        Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
+        attention modules will be "enabled" and all forward passes will update the caches. This behaviour
+        can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
+        using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
+        """
+        return self.layers[0].caches_are_enabled()
+
     def reset_caches(self):
-        """Reset the key value caches."""
-        if not self.caches_are_setup():
+        """
+        Resets KV-cache buffers on relevant attention modules to zero, and resets cache positions
+        to zero, without deleting or reallocating cache tensors.
+        """
+        if not self.caches_are_enabled():
             raise RuntimeError(
                 "Key value caches are not setup. Call ``setup_caches()`` first."
             )
@@ -461,13 +477,6 @@ def reset_caches(self):
         for layer in self.layers:
             layer.reset_cache()
 
-    def caches_are_enabled(self) -> bool:
-        """
-        Checks if the key value caches are enabled. KV-caches must also have been setup
-        for them to be enabled.
- """ - return self.layers[0].caches_are_enabled() and self.caches_are_setup() - @torch.compiler.disable def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]: """ @@ -605,7 +614,13 @@ def forward( # input tensor of shape [b, s] seq_len = tokens.shape[1] - self._validate_inputs(seq_len, mask, encoder_input, encoder_mask, input_pos) + self._validate_inputs( + seq_len, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) # shape: [b, s, d] h = self.tok_embeddings(tokens) @@ -777,7 +792,7 @@ def setup_caches( self.decoder_max_cache_seq_len = self.decoder_max_cache_seq_len for layer in self.layers: - layer.setup_cache( + layer.setup_caches( batch_size, dtype, self.encoder_max_cache_seq_len, @@ -785,20 +800,20 @@ def setup_caches( ) @property - def encoder_caches_are_setup(self) -> bool: + def encoder_caches_are_enabled(self) -> bool: """Checks if there are any :class:`~torchtune.modules.TransformerCrossAttentionLayer`, or :class:`~torchtune.modules.fusion.FusionLayer` layers which have cache enabled. """ return self.encoder_max_cache_seq_len is not None @property - def decoder_caches_are_setup(self) -> bool: + def decoder_caches_are_enabled(self) -> bool: """Check if the key value caches are setup.""" return self.decoder_max_cache_seq_len is not None def reset_caches(self): """Reset the key value caches.""" - if not (self.encoder_caches_are_setup or self.decoder_caches_are_setup): + if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled): raise RuntimeError( "Key value caches are not setup. Call ``setup_caches()`` first." ) @@ -873,13 +888,13 @@ def forward( # shape: [b, s, d] h = self.tok_embeddings(tokens) - if self.decoder_caches_are_setup: + if self.decoder_caches_are_enabled: if mask is None: raise ValueError( "KV-caches for self-attention layers are setup for inference mode, masks must be provided!" " Use the `mask` arg to provide a mask." ) - if self.encoder_caches_are_setup: + if self.encoder_caches_are_enabled: if encoder_mask is None: raise ValueError( "KV-caches for cross-attention/fusion layers are setup for inference mode, encoder masks must be provided!" @@ -887,8 +902,8 @@ def forward( ) if ( - self.encoder_caches_are_setup - or self.decoder_caches_are_setup + self.encoder_caches_are_enabled + or self.decoder_caches_are_enabled and input_pos is None ): raise ValueError(