Support only-within-sequence attention for MosaicGPT (#266)
Makes it possible to restrict attention to tokens within the same source sequence when using pre-concatenated text dataloading.
alextrott16 authored Mar 30, 2023
1 parent daf5a79 commit a80f35d
Showing 7 changed files with 208 additions and 13 deletions.
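
For orientation, a hedged sketch of how this feature would be switched on end to end. Only the keys this commit actually reads are shown; the dataset path, tokenizer, and token id are placeholders, and the remaining dataset/model fields are elided here.

from omegaconf import OmegaConf as om

# Dataloader side: with pre-concatenated (concat_tokens) data, tell the text
# dataloader which separator token marks sequence boundaries. Supply either
# eos_token_id or bos_token_id, never both.
dataloader_cfg = om.create({
    'name': 'text',
    'dataset': {
        'local': '/path/to/mds-dataset',  # placeholder path
        'tokenizer_name': 'gpt2',         # placeholder tokenizer
        'eos_token_id': 50256,            # GPT-2's EOS id; placeholder choice
        # ... remaining dataset fields as usual ...
    },
    # ... remaining dataloader fields as usual ...
})

# Model side: restrict attention to tokens sharing a sequence_id. Requires the
# torch or triton attention implementation (see the config validation below).
model_cfg = om.create({
    'attn_impl': 'triton',
    'attn_uses_sequence_id': True,
    # ... remaining model fields as usual ...
})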
57 changes: 55 additions & 2 deletions examples/common/text_data.py
@@ -5,7 +5,7 @@

import os
from itertools import islice
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
@@ -137,13 +137,57 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
return token_sample


class ConcatenatedSequenceCollatorWrapper:
"""Collator wrapper to add sequence_id to batch."""

def __init__(self,
base_collator: Callable,
eos_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None):
self.base_collator = base_collator
if (eos_token_id is None) and (bos_token_id is None):
raise ValueError(
'Must supply a value for either eos_token_id or bos_token_id, but got None for both.'
)
if (eos_token_id is not None) and (bos_token_id is not None):
raise ValueError(
'Cannot use *both* EOS and BOS tokens for detecting sequence boundaries. ' +\
'Please supply `eos_token_id` if sequences end with an EOS token, or use ' +\
'`bos_token_id` if sequences start with a BOS token.'
)
if eos_token_id is None:
self.split_token_id = bos_token_id
self.bos_mode = True
else:
self.split_token_id = eos_token_id
self.bos_mode = False

def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
batch = self.base_collator(examples)
batch['sequence_id'] = self.get_sequence_id_from_batch(batch)
return batch

def get_sequence_id_from_batch(
self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
is_separator = torch.eq(batch['input_ids'], self.split_token_id)
cumulative_sep = torch.cumsum(is_separator,
dim=1).to(batch['input_ids'].dtype)
# If separator token is bos, we're already done
if self.bos_mode:
return cumulative_sep

# If separator token is eos, right shift 1 space
left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1))
return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)


def build_text_dataloader(cfg: DictConfig, device_batch_size: int):
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
if cfg.dataset.get('group_method', None) is not None:
raise NotImplementedError(
'group_method is deprecated and has been removed.\nTo ' +
'concatenate, use the --concat_tokens ' +
'argument when creating your MDS dataset with concat_c4.py')
'argument when creating your MDS dataset with convert_dataset.py')
dataset = StreamingTextDataset(
local=cfg.dataset.local,
tokenizer_name=cfg.dataset.tokenizer_name,
@@ -166,6 +210,15 @@ def build_text_dataloader(cfg: DictConfig, device_batch_size: int):
mlm=mlm_probability is not None,
mlm_probability=mlm_probability)

eos_token_id = cfg.dataset.get('eos_token_id')
bos_token_id = cfg.dataset.get('bos_token_id')
if (eos_token_id is not None) or (bos_token_id is not None):
# Note: Will raise an error if both are non-None
collate_fn = ConcatenatedSequenceCollatorWrapper(
base_collator=collate_fn,
eos_token_id=eos_token_id,
bos_token_id=bos_token_id)

return DataLoader(
dataset,
collate_fn=collate_fn,
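A quick usage sketch for the new ConcatenatedSequenceCollatorWrapper above. The identity base collator and token id 5 are placeholders; in the real dataloader the wrapper sits around the configured language-modeling collator, so tokenization and padding are unchanged and only the sequence_id key is added. The expected output matches the EOS-mode case in the test added below.

import torch

from examples.common.text_data import ConcatenatedSequenceCollatorWrapper

wrapper = ConcatenatedSequenceCollatorWrapper(
    base_collator=lambda examples: examples[0],  # placeholder collator
    eos_token_id=5)

batch = wrapper([{'input_ids': torch.tensor([[0, 1, 2, 5, 0, 1, 5, 0, 6]])}])
print(batch['sequence_id'])
# tensor([[0, 0, 0, 0, 1, 1, 1, 2, 2]])
# EOS mode shifts the cumulative separator count right by one so each EOS token
# is counted with the sequence it ends; with bos_token_id=5 instead, the same
# inputs would yield [[0, 0, 0, 1, 1, 1, 2, 2, 2]].
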
11 changes: 6 additions & 5 deletions examples/llm/src/models/layers/attention.py
@@ -203,9 +203,9 @@ def triton_flash_attn_fn(

if key_padding_mask is not None:
warnings.warn(
'Propogating key_padding_mask to the attention module ' +\
'Propagating key_padding_mask to the attention module ' +\
'and applying it within the attention module can cause ' +\
'unneccessary computation/memory usage. Consider integrating ' +\
'unnecessary computation/memory usage. Consider integrating ' +\
'into attn_bias once and passing that to each attention ' +\
'module instead.'
)
@@ -346,15 +346,16 @@ def forward(self,
return self.out_proj(context), attn_weights, past_key_value


def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal):
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
use_sequence_id):
if attn_impl == 'flash':
return None
elif attn_impl in ['torch', 'triton']:
if alibi:
if prefix_lm or not causal:
if (prefix_lm or not causal) or use_sequence_id:
return (1, n_heads, seq_len, seq_len)
return (1, n_heads, 1, seq_len)
elif prefix_lm:
elif prefix_lm or use_sequence_id:
return (1, 1, seq_len, seq_len)
return None
else:
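The practical effect of the attn_bias_shape change above: with torch/triton attention, enabling use_sequence_id forces a full seq_len x seq_len bias so per-(query, key) masking is possible, whereas alibi alone only needs one bias row per head, broadcast over queries. A small illustration (the import path is assumed from the file location):

from examples.llm.src.models.layers.attention import attn_bias_shape

# alibi only: one bias row per head, broadcast over the query dimension.
print(attn_bias_shape('triton', n_heads=16, seq_len=2048, alibi=True,
                      prefix_lm=False, causal=True, use_sequence_id=False))
# (1, 16, 1, 2048)

# alibi + sequence-id masking: the bias is materialized per (query, key) pair.
print(attn_bias_shape('triton', n_heads=16, seq_len=2048, alibi=True,
                      prefix_lm=False, causal=True, use_sequence_id=True))
# (1, 16, 2048, 2048)

# sequence-id masking without alibi: the head dimension can be broadcast.
print(attn_bias_shape('torch', n_heads=16, seq_len=2048, alibi=False,
                      prefix_lm=False, causal=True, use_sequence_id=True))
# (1, 1, 2048, 2048)

# flash never materializes a bias, which is why the configuration below rejects
# attn_uses_sequence_id together with attn_impl='flash'.
print(attn_bias_shape('flash', n_heads=16, seq_len=2048, alibi=False,
                      prefix_lm=False, causal=True, use_sequence_id=True))
# None
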
12 changes: 12 additions & 0 deletions examples/llm/src/models/mosaic_gpt/configuration_mosaic_gpt.py
@@ -28,6 +28,7 @@ def __init__(
attn_clip_qkv: Optional[float] = None,
softmax_scale: Optional[float] = None,
prefix_lm: Optional[bool] = False,
attn_uses_sequence_id: Optional[bool] = False,
alibi: bool = False,
alibi_bias_max: int = 8,
init_device: str = 'cpu',
@@ -70,6 +71,10 @@ def __init__(
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
init_device (str): The device to use for parameter initialization.
@@ -104,6 +109,7 @@ def __init__(
self.attn_clip_qkv = attn_clip_qkv
self.softmax_scale = softmax_scale
self.prefix_lm = prefix_lm
self.attn_uses_sequence_id = attn_uses_sequence_id
self.alibi = alibi
self.alibi_bias_max = alibi_bias_max
self.init_device = init_device
@@ -145,6 +151,12 @@ def _validate_config(self):
if self.alibi and self.attn_impl not in ['torch', 'triton']:
raise NotImplementedError(
'alibi only implemented with torch and triton attention.')
if self.attn_uses_sequence_id and self.attn_impl not in [
'torch', 'triton'
]:
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.'
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
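Based on the validation added above, a hedged sketch of how the new flag interacts with attn_impl. It assumes _validate_config runs at construction time and that the remaining MosaicGPTConfig arguments have usable defaults; both are suggested, but not shown, by the truncated hunks.

from examples.llm.src.models.mosaic_gpt.configuration_mosaic_gpt import MosaicGPTConfig

# Supported: torch/triton attention materializes the dense attention bias that
# sequence-id masking needs.
cfg = MosaicGPTConfig(attn_impl='torch', attn_uses_sequence_id=True)

# Unsupported: flash attention has no materialized bias, so this should raise
# the NotImplementedError added in _validate_config above.
try:
    MosaicGPTConfig(attn_impl='flash', attn_uses_sequence_id=True)
except NotImplementedError as err:
    print(err)
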
63 changes: 58 additions & 5 deletions examples/llm/src/models/mosaic_gpt/mosaic_gpt.py
@@ -41,6 +41,7 @@ def __init__(self, config: MosaicGPTConfig):

self.attn_impl = config.attn_impl
self.prefix_lm = config.prefix_lm
self.attn_uses_sequence_id = config.attn_uses_sequence_id
self.alibi = config.alibi
self.alibi_bias_max = config.alibi_bias_max

@@ -107,7 +108,8 @@ def __init__(self, config: MosaicGPTConfig):
config.max_seq_len,
self.alibi,
prefix_lm=self.prefix_lm,
causal=self.is_causal)
causal=self.is_causal,
use_sequence_id=self.attn_uses_sequence_id)

if config.no_bias:
for module in self.modules():
@@ -120,11 +122,13 @@ def __init__(self, config: MosaicGPTConfig):
if config.verbose and config.verbose > 2:
print(self)

@torch.no_grad()
def _attn_bias(self,
device,
dtype,
attention_mask: Optional[torch.ByteTensor] = None,
prefix_mask: Optional[torch.ByteTensor] = None):
prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None):
if not self._attn_bias_initialized:
if self.attn_bias_shape:
self.attn_bias = torch.zeros(self.attn_bias_shape,
@@ -153,8 +157,13 @@ def _attn_bias(self,
assert isinstance(prefix_mask, torch.Tensor) # pyright
attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)

# If using torch or triton, we incorporate sequence_id (if appropriate)
if self.attn_uses_sequence_id and sequence_id is not None:
assert isinstance(attn_bias, torch.Tensor) # pyright
attn_bias = self._apply_sequence_id(attn_bias, sequence_id)

# If using torch or triton, we incorporate attention_mask. This will output
# None in place of attention_mask since it will not be futher needed in the
# None in place of attention_mask since it will not be further needed in the
# attention modules.
if attention_mask is not None:
s_k = attention_mask.shape[-1]
@@ -209,12 +218,34 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor,

return attn_bias

def _apply_sequence_id(self, attn_bias: torch.Tensor,
sequence_id: torch.LongTensor):
seq_len = sequence_id.shape[-1]
if seq_len > self.config.max_seq_len:
raise ValueError(
f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
)

# select seq_len subset of attn mask
attn_bias = attn_bias[..., :seq_len, :seq_len]

# Restrict attention to tokens that share the same value
# in sequence_id
cannot_attend = torch.logical_not(
torch.eq(sequence_id.view(-1, seq_len, 1),
sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
min_val = torch.finfo(attn_bias.dtype).min
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)

return attn_bias

def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
prefix_mask: Optional[torch.ByteTensor] = None,
sequence_id: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -242,6 +273,19 @@ def forward(
'prefix_mask is a required argument when MosaicGPT is configured with prefix_lm=True.'
)

if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
raise ValueError(
'sequence_id is a required argument when MosaicGPT is configured with attn_uses_sequence_id=True ' +\
'and the model is in train mode.'
)
elif (self.attn_uses_sequence_id is False) and (sequence_id
is not None):
warnings.warn(
'MosaicGPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' +\
'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
)

S = input_ids.size(1)

assert (
@@ -294,7 +338,8 @@ def forward(
device=x.device,
dtype=x.dtype,
attention_mask=attention_mask,
prefix_mask=prefix_mask)
prefix_mask=prefix_mask,
sequence_id=sequence_id)

# initialize the past key values cache if it should be used
if use_cache and past_key_values is None:
@@ -364,6 +409,11 @@ def prepare_inputs_for_generation(self,
raise NotImplementedError(
'MosaicGPT does not support generation with right padding.')

if self.attn_uses_sequence_id and self.training:
sequence_id = torch.zeros_like(input_ids[:1])
else:
sequence_id = None

if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)

@@ -382,6 +432,7 @@
'input_ids': input_ids,
'attention_mask': attention_mask,
'prefix_mask': prefix_mask,
'sequence_id': sequence_id,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
}
Expand Down Expand Up @@ -477,12 +528,14 @@ def forward(self, batch):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask'].bool(
) if 'attention_mask' in batch else None
sequence_id = batch.get('sequence_id', None)
prefix_mask = batch['bidirectional_mask'].bool(
) if 'bidirectional_mask' in batch else None
# Note: prefix_mask is only used if model.prefix_lm is True
return self.model(input_ids=input_ids,
attention_mask=attention_mask,
prefix_mask=prefix_mask)
prefix_mask=prefix_mask,
sequence_id=sequence_id)

def loss(self, outputs, batch):
targets = self.get_targets(batch)
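To see what _apply_sequence_id does to attention, here is a standalone reproduction of its tensor logic on a toy packed row (plain torch, not a call into MosaicGPT). The causal mask is applied separately inside the attention modules.

import torch

# One packed row: tokens 0-1 belong to sub-sequence 0, tokens 2-3 to sub-sequence 1.
sequence_id = torch.tensor([[0, 0, 1, 1]])
seq_len = sequence_id.shape[-1]

# Same construction as _apply_sequence_id: True wherever the query token and
# the key token come from different sub-sequences.
cannot_attend = torch.logical_not(
    torch.eq(sequence_id.view(-1, seq_len, 1),
             sequence_id.view(-1, 1, seq_len))).unsqueeze(1)

scores = torch.zeros(1, 1, seq_len, seq_len)  # uniform pre-softmax scores
attn_bias = torch.zeros_like(scores).masked_fill(
    cannot_attend, torch.finfo(scores.dtype).min)

print(torch.softmax(scores + attn_bias, dim=-1)[0, 0])
# roughly:
# [[0.5, 0.5, 0.0, 0.0],
#  [0.5, 0.5, 0.0, 0.0],
#  [0.0, 0.0, 0.5, 0.5],
#  [0.0, 0.0, 0.5, 0.5]]
# Attention weight across the sequence boundary collapses to ~0.
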
26 changes: 25 additions & 1 deletion examples/llm/tests/test_dataloader.py
@@ -8,7 +8,8 @@
import torch
from omegaconf import OmegaConf as om

from examples.common.text_data import build_text_dataloader
from examples.common.text_data import (ConcatenatedSequenceCollatorWrapper,
build_text_dataloader)
from examples.llm.src import build_text_denoising_dataloader


@@ -73,6 +74,29 @@ def test_correct_padding(tokenizer_name, pretokenize, batch_size=4):
assert torch.equal(a, b)


@pytest.mark.parametrize(('eos_token_id', 'bos_token_id'),
[(5, None), (None, 5),
pytest.param(5, 5, marks=pytest.mark.xfail)])
def test_sequence_id_wrapper(eos_token_id, bos_token_id):
wrapper = ConcatenatedSequenceCollatorWrapper(
lambda x: x, # placeholder
eos_token_id=eos_token_id,
bos_token_id=bos_token_id,
)

batch = {'input_ids': torch.Tensor([[0, 1, 2, 5, 0, 1, 5, 0, 6]])}
sequence_id = wrapper.get_sequence_id_from_batch(batch)

if eos_token_id is not None:
assert torch.equal(sequence_id,
torch.Tensor([[0, 0, 0, 0, 1, 1, 1, 2, 2]]))
elif bos_token_id is not None:
assert torch.equal(sequence_id,
torch.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2]]))
else:
raise NotImplementedError()


@pytest.mark.parametrize('decoder_only_format', [True, False])
@pytest.mark.parametrize('pretokenize', [True, False])
def test_denoising_dataloader(decoder_only_format, pretokenize):
1 change: 1 addition & 0 deletions examples/llm/tests/test_flash_triton_torch.py
@@ -62,6 +62,7 @@ def gen_bias(attn_impl):
s,
alibi,
prefix_lm=False,
use_sequence_id=False,
causal=causal)
if bs is not None:
attn_bias = torch.zeros(*bs, device=device)