From a80f35d4251754ed0e1b14bd62817f459a9fe683 Mon Sep 17 00:00:00 2001 From: Alex Trott Date: Thu, 30 Mar 2023 08:39:01 -0700 Subject: [PATCH] Support only-within-sequence attention for MosaicGPT (#266) Makes it possible to have attention restricted to tokens within the same source sequence when using pre-concatenated text dataloading. --- examples/common/text_data.py | 57 ++++++++++++++++- examples/llm/src/models/layers/attention.py | 11 ++-- .../mosaic_gpt/configuration_mosaic_gpt.py | 12 ++++ .../llm/src/models/mosaic_gpt/mosaic_gpt.py | 63 +++++++++++++++++-- examples/llm/tests/test_dataloader.py | 26 +++++++- examples/llm/tests/test_flash_triton_torch.py | 1 + examples/llm/tests/test_model.py | 51 +++++++++++++++ 7 files changed, 208 insertions(+), 13 deletions(-) diff --git a/examples/common/text_data.py b/examples/common/text_data.py index 37f85bea4..1af3da5d8 100644 --- a/examples/common/text_data.py +++ b/examples/common/text_data.py @@ -5,7 +5,7 @@ import os from itertools import islice -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, List, Optional import numpy as np import torch @@ -137,13 +137,57 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: return token_sample +class ConcatenatedSequenceCollatorWrapper: + """Collator wrapper to add sequence_id to batch.""" + + def __init__(self, + base_collator: Callable, + eos_token_id: Optional[int] = None, + bos_token_id: Optional[int] = None): + self.base_collator = base_collator + if (eos_token_id is None) and (bos_token_id is None): + raise ValueError( + 'Must supply a value for either eos_token_id or bos_token_id, but got None for both.' + ) + if (eos_token_id is not None) and (bos_token_id is not None): + raise ValueError( + 'Cannot use *both* EOS and BOS tokens for detecting sequence boundaries. ' +\ + 'Please supply `eos_token_id` if sequences end with an EOS token, or use ' +\ + '`bos_token_id` if sequences start with a BOS token.' 
+ ) + if eos_token_id is None: + self.split_token_id = bos_token_id + self.bos_mode = True + else: + self.split_token_id = eos_token_id + self.bos_mode = False + + def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]: + batch = self.base_collator(examples) + batch['sequence_id'] = self.get_sequence_id_from_batch(batch) + return batch + + def get_sequence_id_from_batch( + self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + is_separator = torch.eq(batch['input_ids'], self.split_token_id) + cumulative_sep = torch.cumsum(is_separator, + dim=1).to(batch['input_ids'].dtype) + # If separator token is bos, we're already done + if self.bos_mode: + return cumulative_sep + + # If separator token is eos, right shift 1 space + left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1)) + return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1) + + def build_text_dataloader(cfg: DictConfig, device_batch_size: int): assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}' if cfg.dataset.get('group_method', None) is not None: raise NotImplementedError( 'group_method is deprecated and has been removed.\nTo ' + 'concatenate, use the --concat_tokens ' + - 'argument when creating your MDS dataset with concat_c4.py') + 'argument when creating your MDS dataset with convert_dataset.py') dataset = StreamingTextDataset( local=cfg.dataset.local, tokenizer_name=cfg.dataset.tokenizer_name, @@ -166,6 +210,15 @@ def build_text_dataloader(cfg: DictConfig, device_batch_size: int): mlm=mlm_probability is not None, mlm_probability=mlm_probability) + eos_token_id = cfg.dataset.get('eos_token_id') + bos_token_id = cfg.dataset.get('bos_token_id') + if (eos_token_id is not None) or (bos_token_id is not None): + # Note: Will raise an error if both are non-None + collate_fn = ConcatenatedSequenceCollatorWrapper( + base_collator=collate_fn, + eos_token_id=eos_token_id, + bos_token_id=bos_token_id) + return DataLoader( dataset, collate_fn=collate_fn, diff --git a/examples/llm/src/models/layers/attention.py b/examples/llm/src/models/layers/attention.py index 8d423b5dc..ff80abe7d 100644 --- a/examples/llm/src/models/layers/attention.py +++ b/examples/llm/src/models/layers/attention.py @@ -203,9 +203,9 @@ def triton_flash_attn_fn( if key_padding_mask is not None: warnings.warn( - 'Propogating key_padding_mask to the attention module ' +\ + 'Propagating key_padding_mask to the attention module ' +\ 'and applying it within the attention module can cause ' +\ - 'unneccessary computation/memory usage. Consider integrating ' +\ + 'unnecessary computation/memory usage. Consider integrating ' +\ 'into attn_bias once and passing that to each attention ' +\ 'module instead.' 
) @@ -346,15 +346,16 @@ def forward(self, return self.out_proj(context), attn_weights, past_key_value -def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal): +def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, + use_sequence_id): if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: if alibi: - if prefix_lm or not causal: + if (prefix_lm or not causal) or use_sequence_id: return (1, n_heads, seq_len, seq_len) return (1, n_heads, 1, seq_len) - elif prefix_lm: + elif prefix_lm or use_sequence_id: return (1, 1, seq_len, seq_len) return None else: diff --git a/examples/llm/src/models/mosaic_gpt/configuration_mosaic_gpt.py b/examples/llm/src/models/mosaic_gpt/configuration_mosaic_gpt.py index 2ac86c363..284596d52 100644 --- a/examples/llm/src/models/mosaic_gpt/configuration_mosaic_gpt.py +++ b/examples/llm/src/models/mosaic_gpt/configuration_mosaic_gpt.py @@ -28,6 +28,7 @@ def __init__( attn_clip_qkv: Optional[float] = None, softmax_scale: Optional[float] = None, prefix_lm: Optional[bool] = False, + attn_uses_sequence_id: Optional[bool] = False, alibi: bool = False, alibi_bias_max: int = 8, init_device: str = 'cpu', @@ -70,6 +71,10 @@ def __init__( prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix can attend to one another bi-directionally. Tokens outside the prefix use causal attention. + attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. + When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates + which sub-sequence each token belongs to. + Defaults to ``False`` meaning any provided `sequence_id` will be ignored. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. init_device (str): The device to use for parameter initialization. @@ -104,6 +109,7 @@ def __init__( self.attn_clip_qkv = attn_clip_qkv self.softmax_scale = softmax_scale self.prefix_lm = prefix_lm + self.attn_uses_sequence_id = attn_uses_sequence_id self.alibi = alibi self.alibi_bias_max = alibi_bias_max self.init_device = init_device @@ -145,6 +151,12 @@ def _validate_config(self): if self.alibi and self.attn_impl not in ['torch', 'triton']: raise NotImplementedError( 'alibi only implemented with torch and triton attention.') + if self.attn_uses_sequence_id and self.attn_impl not in [ + 'torch', 'triton' + ]: + raise NotImplementedError( + 'attn_uses_sequence_id only implemented with torch and triton attention.' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' 
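The `attn_uses_sequence_id` option documented above restricts attention to tokens that share a `sequence_id`. As a rough, self-contained sketch of that idea (toy tensors and a hypothetical helper name, not code from this patch; the actual logic is `_apply_sequence_id` in the `mosaic_gpt.py` changes below), the bias construction looks roughly like this:

    import torch

    def sequence_id_attn_bias(sequence_id: torch.Tensor,
                              dtype: torch.dtype = torch.float32) -> torch.Tensor:
        # sequence_id: (batch, seq_len); tokens sharing a value belong to the
        # same source sequence and may attend to one another.
        b, seq_len = sequence_id.shape
        attn_bias = torch.zeros(b, 1, seq_len, seq_len, dtype=dtype)
        # True wherever query token i and key token j come from different
        # sub-sequences; those positions get the dtype's minimum value so
        # softmax assigns them ~zero attention weight.
        cannot_attend = torch.logical_not(
            torch.eq(sequence_id.view(-1, seq_len, 1),
                     sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
        return attn_bias.masked_fill(cannot_attend, torch.finfo(dtype).min)

    # Two concatenated length-2 sequences: the off-diagonal 2x2 blocks are masked.
    print(sequence_id_attn_bias(torch.tensor([[0, 0, 1, 1]])))

Expressing the restriction as an additive bias lets it share the `attn_bias` tensor already used for ALiBi and prefix-LM masking, which is why `attn_bias_shape` above returns a full `(seq_len, seq_len)` bias whenever `use_sequence_id` is set.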
diff --git a/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py b/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py index 671becbf5..083cee3b5 100644 --- a/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py +++ b/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py @@ -41,6 +41,7 @@ def __init__(self, config: MosaicGPTConfig): self.attn_impl = config.attn_impl self.prefix_lm = config.prefix_lm + self.attn_uses_sequence_id = config.attn_uses_sequence_id self.alibi = config.alibi self.alibi_bias_max = config.alibi_bias_max @@ -107,7 +108,8 @@ def __init__(self, config: MosaicGPTConfig): config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, - causal=self.is_causal) + causal=self.is_causal, + use_sequence_id=self.attn_uses_sequence_id) if config.no_bias: for module in self.modules(): @@ -120,11 +122,13 @@ def __init__(self, config: MosaicGPTConfig): if config.verbose and config.verbose > 2: print(self) + @torch.no_grad() def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None): + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None): if not self._attn_bias_initialized: if self.attn_bias_shape: self.attn_bias = torch.zeros(self.attn_bias_shape, @@ -153,8 +157,13 @@ def _attn_bias(self, assert isinstance(prefix_mask, torch.Tensor) # pyright attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) + # If using torch or triton, we incorporate sequence_id (if appropriate) + if self.attn_uses_sequence_id and sequence_id is not None: + assert isinstance(attn_bias, torch.Tensor) # pyright + attn_bias = self._apply_sequence_id(attn_bias, sequence_id) + # If using torch or triton, we incorporate attention_mask. This will output - # None in place of attention_mask since it will not be futher needed in the + # None in place of attention_mask since it will not be further needed in the # attention modules. if attention_mask is not None: s_k = attention_mask.shape[-1] @@ -209,12 +218,34 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, return attn_bias + def _apply_sequence_id(self, attn_bias: torch.Tensor, + sequence_id: torch.LongTensor): + seq_len = sequence_id.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError( + f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}' + ) + + # select seq_len subset of attn mask + attn_bias = attn_bias[..., :seq_len, :seq_len] + + # Restrict attention to tokens that share the same value + # in sequence_id + cannot_attend = torch.logical_not( + torch.eq(sequence_id.view(-1, seq_len, 1), + sequence_id.view(-1, 1, seq_len))).unsqueeze(1) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + + return attn_bias + def forward( self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -242,6 +273,19 @@ def forward( 'prefix_mask is a required argument when MosaicGPT is configured with prefix_lm=True.' ) + if self.training: + if self.attn_uses_sequence_id and sequence_id is None: + raise ValueError( + 'sequence_id is a required argument when MosaicGPT is configured with attn_uses_sequence_id=True ' +\ + 'and the model is in train mode.' 
+ ) + elif (self.attn_uses_sequence_id is False) and (sequence_id + is not None): + warnings.warn( + 'MosaicGPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' +\ + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.' + ) + S = input_ids.size(1) assert ( @@ -294,7 +338,8 @@ def forward( device=x.device, dtype=x.dtype, attention_mask=attention_mask, - prefix_mask=prefix_mask) + prefix_mask=prefix_mask, + sequence_id=sequence_id) # initialize the past key values cache if it should be used if use_cache and past_key_values is None: @@ -364,6 +409,11 @@ def prepare_inputs_for_generation(self, raise NotImplementedError( 'MosaicGPT does not support generation with right padding.') + if self.attn_uses_sequence_id and self.training: + sequence_id = torch.zeros_like(input_ids[:1]) + else: + sequence_id = None + if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) @@ -382,6 +432,7 @@ def prepare_inputs_for_generation(self, 'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, + 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache'), } @@ -477,12 +528,14 @@ def forward(self, batch): input_ids = batch['input_ids'] attention_mask = batch['attention_mask'].bool( ) if 'attention_mask' in batch else None + sequence_id = batch.get('sequence_id', None) prefix_mask = batch['bidirectional_mask'].bool( ) if 'bidirectional_mask' in batch else None # Note: prefix_mask is only used if model.prefix_lm is True return self.model(input_ids=input_ids, attention_mask=attention_mask, - prefix_mask=prefix_mask) + prefix_mask=prefix_mask, + sequence_id=sequence_id) def loss(self, outputs, batch): targets = self.get_targets(batch) diff --git a/examples/llm/tests/test_dataloader.py b/examples/llm/tests/test_dataloader.py index a76327422..ad739f18c 100644 --- a/examples/llm/tests/test_dataloader.py +++ b/examples/llm/tests/test_dataloader.py @@ -8,7 +8,8 @@ import torch from omegaconf import OmegaConf as om -from examples.common.text_data import build_text_dataloader +from examples.common.text_data import (ConcatenatedSequenceCollatorWrapper, + build_text_dataloader) from examples.llm.src import build_text_denoising_dataloader @@ -73,6 +74,29 @@ def test_correct_padding(tokenizer_name, pretokenize, batch_size=4): assert torch.equal(a, b) +@pytest.mark.parametrize(('eos_token_id', 'bos_token_id'), + [(5, None), (None, 5), + pytest.param(5, 5, marks=pytest.mark.xfail)]) +def test_sequence_id_wrapper(eos_token_id, bos_token_id): + wrapper = ConcatenatedSequenceCollatorWrapper( + lambda x: x, # placeholder + eos_token_id=eos_token_id, + bos_token_id=bos_token_id, + ) + + batch = {'input_ids': torch.Tensor([[0, 1, 2, 5, 0, 1, 5, 0, 6]])} + sequence_id = wrapper.get_sequence_id_from_batch(batch) + + if eos_token_id is not None: + assert torch.equal(sequence_id, + torch.Tensor([[0, 0, 0, 0, 1, 1, 1, 2, 2]])) + elif bos_token_id is not None: + assert torch.equal(sequence_id, + torch.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2]])) + else: + raise NotImplementedError() + + @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('pretokenize', [True, False]) def test_denoising_dataloader(decoder_only_format, pretokenize): diff --git a/examples/llm/tests/test_flash_triton_torch.py b/examples/llm/tests/test_flash_triton_torch.py index a74fddd2a..220ae8150 100644 --- a/examples/llm/tests/test_flash_triton_torch.py 
+++ b/examples/llm/tests/test_flash_triton_torch.py @@ -62,6 +62,7 @@ def gen_bias(attn_impl): s, alibi, prefix_lm=False, + use_sequence_id=False, causal=causal) if bs is not None: attn_bias = torch.zeros(*bs, device=device) diff --git a/examples/llm/tests/test_model.py b/examples/llm/tests/test_model.py index 3c90e6129..24066a3c7 100644 --- a/examples/llm/tests/test_model.py +++ b/examples/llm/tests/test_model.py @@ -567,6 +567,57 @@ def test_forward_with_padding(attention_impl, device, alibi): atol=1e-6 if attention_impl == 'torch' else 1e-8) +@pytest.mark.parametrize('attention_impl', ['torch', 'triton']) +def test_advanced_mask_building(attention_impl): + # Test that the correct attention mask is created when both + # prefix_mask and sequence_id are used + hf_config = MosaicGPTConfig(init_device='cpu', + d_model=16, + n_heads=1, + n_layers=1, + mlp_ratio=1, + max_seq_len=256, + emb_pdrop=0.0, + resid_pdrop=0.0, + attn_impl=attention_impl, + prefix_lm=True, + attn_uses_sequence_id=True, + alibi=False) + mosaic_gpt = MosaicGPT(hf_config) + mosaic_gpt.eval() + + prefix_mask = torch.ByteTensor([[1, 1, 0, 0, 1, 1, 1, 0]]) + sequence_id = torch.LongTensor([[0, 0, 0, 0, 1, 1, 1, 1]]) + + attn_bias, _ = mosaic_gpt._attn_bias(device=mosaic_gpt.device, + dtype=torch.float32, + attention_mask=None, + prefix_mask=prefix_mask, + sequence_id=sequence_id) + + assert isinstance(attn_bias, torch.Tensor) + assert attn_bias.shape == torch.Size([1, 1, 8, 8]) + + # We'll construct the expected value of attn_bias and then compare. + can_attend = torch.tensor([ + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 1, 1, 1], + ]) + can_attend = can_attend.bool().view(1, 1, 8, 8) + expected_attn_bias = torch.zeros_like(attn_bias) + expected_attn_bias = expected_attn_bias.masked_fill( + torch.logical_not(can_attend), + torch.finfo(attn_bias.dtype).min) + + assert torch.equal(attn_bias, expected_attn_bias) + + @pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'), ('flash', 'gpu'), ('triton', 'gpu'),