Adding flash attention to GPT2 #27479

Closed · wants to merge 19 commits
16 changes: 16 additions & 0 deletions docs/source/en/model_doc/gpt2.md
@@ -81,6 +81,22 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
- [Token classification task guide](../tasks/token_classification)
- [Causal language modeling task guide](../tasks/language_modeling)

### Using Flash Attention 2
Flash Attention 2 is a faster, more memory-efficient implementation of attention: it avoids materializing the full attention matrix, so memory use grows roughly linearly rather than quadratically with sequence length, and inference speeds up accordingly. It is particularly effective for long-sequence and large-scale generation. To use it, make sure your hardware is compatible and install the package with:

```python
pip install -U flash-attn --no-build-isolation
```

Contributor: This should be an `sh` code fence, not `python`; it is a shell command, not a Python script.

Use the model with Flash Attention 2 as follows:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
```

Contributor: The leading space should be removed, otherwise it causes a syntax error.
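Once the model and tokenizer are loaded, generation works as usual. A minimal sketch continuing the snippet above (the prompt text is arbitrary):

```python
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```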

## GPT2Config

[[autodoc]] GPT2Config
2 changes: 1 addition & 1 deletion docs/source/en/perf_infer_gpu_one.md
@@ -36,7 +36,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
FlashAttention-2 supports inference with Llama, Mistral, Falcon, Bark, and GPT2 models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Before you begin, make sure you have FlashAttention-2 installed (see the [installation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) guide for more details about prerequisites):
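A quick way to confirm the setup before enabling FlashAttention-2 is sketched below; it assumes `transformers.utils.is_flash_attn_2_available`, which recent Transformers releases expose, and a CUDA device with fp16/bf16 support:

```python
# Hedged environment check: confirms a CUDA device is visible and that the
# flash-attn package is importable before requesting use_flash_attention_2=True.
import torch
from transformers.utils import is_flash_attn_2_available

print("CUDA available:", torch.cuda.is_available())
print("FlashAttention-2 available:", is_flash_attn_2_available())
```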

src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -147,6 +147,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.is_causal = True
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.pruned_heads = set()
@@ -264,7 +265,8 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
# (batch, head, seq_length, head_features)
return tensor.permute(0, 2, 1, 3)

def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
@@ -328,6 +330,73 @@ def forward(
return outputs # a, present, (attentions)


class DecisionTransformerGPT2FlashAttention(DecisionTransformerGPT2Attention):
"""
DecisionTransformerGPT2FlashAttention inherits from `DecisionTransformerGPT2Attention`, as the weights of the
module stay untouched. The only required change is in the forward pass, where it needs to correctly call the
public API of flash attention and deal with padding tokens in case the input contains any of them.
"""

def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
# Prepare query, key, and value
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

if use_cache is True:
present = (key, value)
else:
present = None

# Apply Flash Attention Forward
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
# Flash Attention forward pass
attn_output = self._flash_attention_forward(
query, key, value, attention_mask, query.size(-2), self.attn_dropout.p, softmax_scale=None
)

# Merge heads and project back to hidden size
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)

return outputs
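# Note: `_flash_attention_forward` is not shown in this diff; in the FlashAttention-2
# integrations it is a helper on the attention class that also unpads the inputs when a
# padding mask is present. The function below is only an illustrative sketch of its core,
# assuming the `flash_attn` package's `flash_attn_func` and ignoring padded inputs.
def _flash_attention_forward_sketch(query, key, value, attention_mask, query_length, dropout=0.0, softmax_scale=None):
    from flash_attn import flash_attn_func

    # flash_attn_func expects fp16/bf16 tensors shaped (batch, seq_len, num_heads, head_dim),
    # whereas _split_heads returns (batch, num_heads, seq_len, head_dim), so the head and
    # sequence dimensions are swapped around the kernel call.
    query, key, value = (t.transpose(1, 2) for t in (query, key, value))
    attn_output = flash_attn_func(query, key, value, dropout_p=dropout, softmax_scale=softmax_scale, causal=True)
    # Swap back so the caller can reuse _merge_heads unchanged.
    return attn_output.transpose(1, 2)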


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
@@ -354,12 +423,18 @@ def __init__(self, config, layer_idx=None):
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
self.attn = (
DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else DecisionTransformerGPT2FlashAttention(config, layer_idx=layer_idx)
)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

if config.add_cross_attention:
self.crossattention = DecisionTransformerGPT2Attention(
config, is_cross_attention=True, layer_idx=layer_idx
self.crossattention = (
DecisionTransformerGPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else DecisionTransformerGPT2FlashAttention(config, is_cross_attention=True, layer_idx=layer_idx)
)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
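# The `_flash_attn_2_enabled` attribute checked above is a private flag on the config. In this
# generation of Transformers it is set internally by `from_pretrained(..., use_flash_attention_2=True)`
# once the environment checks pass; the lines below only illustrate how the selection reacts to it
# (assigning the private attribute by hand is for demonstration, not the supported entry point).
from transformers import DecisionTransformerConfig

demo_config = DecisionTransformerConfig()
print(getattr(demo_config, "_flash_attn_2_enabled", False))  # False -> DecisionTransformerGPT2Attention
demo_config._flash_attn_2_enabled = True
print(getattr(demo_config, "_flash_attn_2_enabled", False))  # True -> DecisionTransformerGPT2FlashAttention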

@@ -411,7 +486,8 @@ def forward(
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
# add cross attentions if we output attention weights
outputs = outputs + cross_attn_outputs[2:]

residual = hidden_states
hidden_states = self.ln_2(hidden_states)
Expand All @@ -424,7 +500,8 @@ def forward(
else:
outputs = (hidden_states,) + outputs[1:]

return outputs # hidden_states, present, (attentions, cross_attentions)
# hidden_states, present, (attentions, cross_attentions)
return outputs


class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
@@ -551,35 +628,38 @@ def forward(
position_ids = position_ids.unsqueeze(0)

# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
if getattr(self.config, "_flash_attn_2_enabled", False):
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
encoder_attention_mask.bool()
if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
else None
)
else:
encoder_attention_mask = None
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask[:, None, None, :]

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
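# The two branches above use different mask conventions: the FlashAttention-2 path keeps the
# 2D padding mask (or drops it entirely when nothing is masked), while the eager path expands
# it to a 4D additive bias. A small illustration of the difference:
import torch

mask_2d = torch.tensor([[1, 1, 1, 0]])  # 1 = attend, 0 = padding, as produced by the tokenizer

# FlashAttention-2 path: pass a boolean 2D mask only if there is actual padding.
fa2_mask = mask_2d.bool() if 0 in mask_2d else None

# Eager path: broadcastable (batch, 1, 1, seq_len) bias added to the attention scores,
# with masked positions set to the dtype's most negative value before the softmax.
additive = (1.0 - mask_2d[:, None, None, :].to(torch.float16)) * torch.finfo(torch.float16).min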

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -922,9 +1002,12 @@ def forward(
x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

# get predictions
return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
action_preds = self.predict_action(x[:, 1]) # predict next action given state
# predict next return given state and action
return_preds = self.predict_return(x[:, 2])
# predict next state given state and action
state_preds = self.predict_state(x[:, 2])
# predict next action given state
action_preds = self.predict_action(x[:, 1])
if not return_dict:
return (state_preds, action_preds, return_preds)
