41 changes: 34 additions & 7 deletions megatron/model/transformer.py
@@ -6,9 +6,8 @@
import numpy as np
import torch
import torch.nn.functional as F

from typing import Optional
from torch import nn
from packaging.version import Version

from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
from megatron.utils import print_rank_0
@@ -36,9 +35,15 @@
rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
    import flash_attn as _flash_attn
    if Version(getattr(_flash_attn, "__version__", "1")) >= Version("2"):
        from flash_attn.flash_attn_interface import flash_attn_func
        FLASH_VERSION = 2
    else:
        from flash_attn.flash_attn_interface import flash_attn_unpadded_func
        FLASH_VERSION = 1
except ImportError:
    flash_attn_unpadded_func = None
    FLASH_VERSION = None
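
For reference (not part of the diff): a minimal sketch of the two call signatures this version gate selects between, assuming a CUDA build of flash-attn is installed; tensor sizes are placeholders.

import torch
from packaging.version import Version
import flash_attn

b, s, h, d = 2, 128, 16, 64  # batch, seqlen, heads, head dim (illustrative)
q = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

if Version(getattr(flash_attn, "__version__", "1")) >= Version("2"):
    from flash_attn.flash_attn_interface import flash_attn_func
    # flash-attn >= 2 accepts batched (B, S, H, D) tensors directly.
    out = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True)  # -> (B, S, H, D)
else:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
    # flash-attn 1.x expects token-packed (B*S, H, D) tensors plus cumulative
    # sequence lengths marking each sequence boundary.
    cu_seqlens = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32, device=q.device)
    out = flash_attn_unpadded_func(
        q.reshape(b * s, h, d), k.reshape(b * s, h, d), v.reshape(b * s, h, d),
        cu_seqlens, cu_seqlens, s, s, 0.0, softmax_scale=None, causal=True)  # -> (B*S, H, D)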


""" We use the following notation throughout this file:
@@ -508,7 +513,7 @@ class FlashSelfAttention(torch.nn.Module):
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
        assert FLASH_VERSION is not None, ('Please install FlashAttention first, '
                                           'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
@@ -521,10 +526,31 @@ def forward(self, q, k, v):
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
        """

        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
        assert all((i.is_cuda for i in (q, k, v)))

        if FLASH_VERSION == 1:
            return self._forward_v1(q, k, v)

        seqlen_q, seqlen_k = q.shape[1], k.shape[1]

        if self.training:
            # during training q, k, v always have the same seqlen
            assert seqlen_k == seqlen_q
            is_causal = self.causal
            dropout_p = self.dropout_p
        else:
            # turn off the FA causal mask after the first autoregressive inference step:
            # only on that first step do q, k, v have the same seqlen
            is_causal = self.causal and (seqlen_q == seqlen_k)
            dropout_p = 0

        output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=is_causal)

        return output
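
For reference (not part of the diff): a usage sketch of the v2 path above, assuming flash-attn and einops are installed and a CUDA device with the fp16/bf16 inputs the asserts require; shapes are illustrative.

import torch
from megatron.model.transformer import FlashSelfAttention

attn = FlashSelfAttention(causal=True, attention_dropout=0.1)

b, s, h, d = 2, 512, 16, 64
q = torch.randn(b, s, h, d, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

attn.train()
out = attn(q, k, v)        # seqlen_q == seqlen_k: causal mask and dropout applied

attn.eval()
q_step = q[:, -1:]         # one new query token during autoregressive decode
out = attn(q_step, k, v)   # seqlen_q != seqlen_k: causal mask and dropout disabled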


    def _forward_v1(self, q, k, v):
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

@@ -647,7 +673,7 @@ def __init__(self, init_method,
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
            if FLASH_VERSION is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
@@ -882,6 +908,7 @@ def forward(self, hidden_states, attention_mask,
            sq, b, np, hn = query_layer.size()
            # Expand kv to be compatible with flash-attn implementation
            # [sq, b, 1, hn] -> [sq, b, np, hn]
            # TODO: This should be skippable for flash 2, but getting illegal memory access.
            key_layer = key_layer.expand((sq, b, np, hn))
            value_layer = value_layer.expand((sq, b, np, hn))
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
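
For reference (not part of the diff): a shape walk-through of the expand + rearrange above with illustrative sizes, showing how the single KV head is broadcast across the query heads and the layout is flipped from (S, B, H, D) to the (B, S, H, D) that flash-attn expects.

import torch
from einops import rearrange

sq, b, np_heads, hn = 512, 2, 16, 64
query_layer = torch.randn(sq, b, np_heads, hn)
key_layer = torch.randn(sq, b, 1, hn)    # single KV head (multi-query attention)
value_layer = torch.randn(sq, b, 1, hn)

# [sq, b, 1, hn] -> [sq, b, np, hn]: expand() broadcasts the size-1 head dim (a view, no copy yet)
key_layer = key_layer.expand((sq, b, np_heads, hn))
value_layer = value_layer.expand((sq, b, np_heads, hn))

# [s, b, ...] -> [b, s, ...]; .contiguous() materializes the broadcast as a real copy
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
           for x in (query_layer, key_layer, value_layer)]
print(q.shape, k.shape, v.shape)   # each torch.Size([2, 512, 16, 64])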