In attention.py, your default selection priority is fa3 > xformers > sdpa > fa2 > eager attn. My rough understanding is that sdpa itself picks between flash-attention and xformers-style kernels, so you pin the priority up front to skip the kernel-selection overhead. What I don't understand is why sdpa is ranked above fa2; the priority I would personally expect is fa3 >> fa2 > xformers > sdpa > eager attn. Could you explain the reasoning?
from typing import Optional

import torch


def attention(
    q,
    k,
    v,
    attn_impl: Optional[str] = None,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """
    q: [B, Lq, Nq, C1]
    k: [B, Lk, Nk, C1]
    v: [B, Lk, Nk, C2]
    """
    assert attn_impl in [
        None,
        "auto",
        "eager",
        "flash_attn_2",
        "flash_attn_3",
        "xformers",
        "sdpa",
        "sage_attn",
        "sparge_attn",
    ]
    if attn_impl is None or attn_impl == "auto":
        # Auto-selection priority: flash_attn_3 > xformers > sdpa > flash_attn_2 > eager.
        if FLASH_ATTN_3_AVAILABLE:
            return flash_attn3(q, k, v, softmax_scale=scale)
        elif XFORMERS_AVAILABLE:
            return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif SDPA_AVAILABLE:
            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif FLASH_ATTN_2_AVAILABLE:
            return flash_attn2(q, k, v, softmax_scale=scale)
        else:
            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
    else:
        # An implementation was requested explicitly.
        if attn_impl == "eager":
            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "flash_attn_3":
            return flash_attn3(q, k, v, softmax_scale=scale)
        elif attn_impl == "flash_attn_2":
            return flash_attn2(q, k, v, softmax_scale=scale)
        elif attn_impl == "xformers":
            return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sdpa":
            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sage_attn":
            return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sparge_attn":
            return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        else:
            raise ValueError(f"Invalid attention implementation: {attn_impl}")
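One point worth noting for the question: PyTorch's scaled_dot_product_attention is itself a dispatcher that can route to a built-in FlashAttention kernel, a memory-efficient (xformers-style) kernel, or a plain math fallback, so ranking sdpa above a standalone flash_attn_2 install is not necessarily a downgrade. A minimal sketch to check that dispatch, assuming a recent PyTorch with the torch.nn.attention.sdpa_kernel context manager (shapes here are hypothetical, and SDPA takes [B, H, L, C] rather than the [B, L, H, C] layout in the excerpt above):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical inputs; SDPA expects [B, num_heads, seq_len, head_dim].
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to its built-in FlashAttention backend; the call raises a
# RuntimeError if that backend cannot handle the inputs (dtype, head_dim,
# mask, ...), which is a quick way to confirm sdpa is running a flash kernel.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)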
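If the ordering is ultimately an empirical question, a micro-benchmark along these lines would show which path wins on a given GPU. This is only a sketch, assuming the flash-attn package is installed; flash_attn_func consumes [B, L, H, C] half-precision tensors while SDPA wants [B, H, L, C], hence the transposes:

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

def bench_ms(fn, iters=50):
    # CUDA-event timing with warmup; returns average milliseconds per call.
    for _ in range(5):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

B, L, H, C = 1, 4096, 16, 64  # hypothetical sizes
q = torch.randn(B, L, H, C, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))

print("flash_attn_2:", bench_ms(lambda: flash_attn_func(q, k, v)), "ms")
print("sdpa:", bench_ms(lambda: F.scaled_dot_product_attention(qt, kt, vt)), "ms")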