On the attention implementation selection mechanism #74

Open
@Lubenwei-nb123

Description

In attention.py, your default selection priority is fa3 > xformers > sdpa > fa2 > eager attn. My rough understanding is that sdpa itself selects among fa- and xformers-style kernels, so to skip the overhead of that kernel selection you pin the priority down up front. What I am not sure about is why sdpa is ranked above fa2; I would have expected the priority to be fa3 >> fa2 > xformers > sdpa > eager attn. Could you shed some light on this?

# Excerpt from attention.py; the availability flags and per-backend wrappers
# (flash_attn3, flash_attn2, xformers_attn, sdpa_attn, eager_attn, sage_attn,
# sparge_attn) are defined earlier in the same file.
from typing import Optional

import torch


def attention(
    q,
    k,
    v,
    attn_impl: Optional[str] = None,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """
    q: [B, Lq, Nq, C1]
    k: [B, Lk, Nk, C1]
    v: [B, Lk, Nk, C2]
    """
    assert attn_impl in [
        None,
        "auto",
        "eager",
        "flash_attn_2",
        "flash_attn_3",
        "xformers",
        "sdpa",
        "sage_attn",
        "sparge_attn",
    ]
    if attn_impl is None or attn_impl == "auto":
        # Auto-selection priority: flash_attn_3 > xformers > sdpa > flash_attn_2 > eager
        if FLASH_ATTN_3_AVAILABLE:
            return flash_attn3(q, k, v, softmax_scale=scale)
        elif XFORMERS_AVAILABLE:
            return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif SDPA_AVAILABLE:
            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif FLASH_ATTN_2_AVAILABLE:
            return flash_attn2(q, k, v, softmax_scale=scale)
        else:
            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
    else:
        # Dispatch to the explicitly requested implementation
        if attn_impl == "eager":
            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "flash_attn_3":
            return flash_attn3(q, k, v, softmax_scale=scale)
        elif attn_impl == "flash_attn_2":
            return flash_attn2(q, k, v, softmax_scale=scale)
        elif attn_impl == "xformers":
            return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sdpa":
            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sage_attn":
            return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sparge_attn":
            return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        else:
            raise ValueError(f"Invalid attention implementation: {attn_impl}")
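
For context on the "sdpa selects among fa and xformers" point: torch.nn.functional.scaled_dot_product_attention is itself a dispatcher that picks a fused backend at call time (a FlashAttention kernel, a memory-efficient/xformers-style kernel, or the plain math fallback) depending on hardware, dtype and mask. Below is a minimal sketch of that dispatch, not the repo's sdpa_attn; the function name and the restricted backend list are only for illustration.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def sdpa_attn_sketch(q, k, v, attn_mask=None, scale=None):
    """Illustrative only: q/k/v in [B, L, N, C], as in the snippet above."""
    # SDPA expects [B, N, L, C], so swap the sequence and head dims
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    # Restricting the allowed backends makes SDPA's internal choice explicit;
    # without the context manager it chooses freely among all available backends.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
    return out.transpose(1, 2)  # back to [B, L, N, C]

Whether the flash kernel that SDPA dispatches to matches a separately installed flash_attn_2 in version and performance depends on the PyTorch build, which is presumably why the relative ranking of sdpa and flash_attn_2 is debatable in the first place.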
