Merge pull request facebookresearch#20 from tenpercent/develop
Fix lints
qianfengz authored Aug 17, 2024
2 parents 5be80a3 + d6b6456 commit cee0980
Showing 6 changed files with 118 additions and 85 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -451,7 +451,7 @@ def get_extensions():
"-Werror",
"-Woverloaded-virtual",
"-mllvm",
"-enable-post-misched=0"
"-enable-post-misched=0",
]
+ generator_flag
+ cc_flag,
6 changes: 3 additions & 3 deletions tests/test_mem_eff_attention.py
@@ -37,13 +37,13 @@
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability("cuda")
sm70_or_better_only = pytest.mark.skipif(
torch.version.cuda and compute_capability < (7, 0), reason="requires sm70+"
torch.version.cuda is not None and compute_capability < (7, 0), reason="requires sm70+"
)
sm75_or_better_only = pytest.mark.skipif(
torch.version.cuda and compute_capability < (7, 5), reason="requires sm75+"
torch.version.cuda is not None and compute_capability < (7, 5), reason="requires sm75+"
)
sm80_or_better_only = pytest.mark.skipif(
torch.version.cuda and compute_capability < (8, 0), reason="requires sm80+"
torch.version.cuda is not None and compute_capability < (8, 0), reason="requires sm80+"
)
skip_if_rocm = pytest.mark.skipif(
torch.version.hip is not None, reason="not supported on ROCm"
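Background on this hunk, not part of the diff: `torch.version.cuda` is a version string on CUDA builds of PyTorch and `None` otherwise, so the old truthiness test and the new `is not None` test select the same cases at runtime; the explicit comparison just keeps the skip condition a plain `bool`, which is presumably what the lint asked for. A small self-contained sketch of the pattern (the function name and sample values are illustrative):

```python
from typing import Optional, Tuple

def requires_sm70_skip(cuda_version: Optional[str], capability: Tuple[int, int]) -> bool:
    """Return True when the test should be skipped on this build/device."""
    # `torch.version.cuda` is a string such as "12.1" on CUDA builds and None
    # on CPU/ROCm builds. `cuda_version and capability < (7, 0)` works at
    # runtime, but its value can be None rather than a bool; the explicit
    # `is not None` comparison keeps the expression strictly boolean.
    return cuda_version is not None and capability < (7, 0)

assert requires_sm70_skip("12.1", (6, 1)) is True
assert requires_sm70_skip("12.1", (8, 0)) is False
assert requires_sm70_skip(None, (6, 1)) is False
```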
4 changes: 2 additions & 2 deletions xformers/attn_bias_utils.py
@@ -39,7 +39,7 @@ def create_attn_bias(
dtype,
requires_grad: bool,
fmt: str,
op: Type[AttentionOpBase],
op: Optional[Type[AttentionOpBase]] = None,
page_size: Optional[int] = None,
):
if bias_type is None or isinstance(None, bias_type):
@@ -59,7 +59,7 @@
* 3
)
attn_bias = attn_bias.expand(batch_size, num_heads, q_len, kv_len)
elif issubclass(op, fmha.triton_splitk.FwOp):
elif op is not None and issubclass(op, fmha.triton_splitk.FwOp):
attn_bias = (
torch.randn(
(batch_size, num_heads_groups, num_heads, q_len, kv_len),
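For context on the `op is not None and issubclass(...)` guard: once `op` gains the default `None`, calling `issubclass(None, ...)` would raise `TypeError`, so the short-circuit check has to come first. A minimal sketch of that pattern with stand-in classes (the class names below are placeholders, not the real xformers types):

```python
from typing import Optional, Type

class AttentionOpBase:
    """Stand-in for the real base class."""

class TritonSplitKFwOp(AttentionOpBase):
    """Stand-in for fmha.triton_splitk.FwOp."""

def bias_kind(op: Optional[Type[AttentionOpBase]] = None) -> str:
    # issubclass() requires a class as its first argument, so guard against
    # None before dispatching on the operator type.
    if op is not None and issubclass(op, TritonSplitKFwOp):
        return "split-k layout"
    return "default layout"

print(bias_kind())                  # default layout
print(bias_kind(TritonSplitKFwOp))  # split-k layout
```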
177 changes: 108 additions & 69 deletions xformers/csrc/attention/hip_fmha/generate_instances.py
@@ -35,8 +35,10 @@
{max_k}>({cap_mode}ForwardParams& param, hipStream_t stream);
"""

FMHA_INFER_INSTANCE_FNAME = "fmha_{mode}_infer_{dtype_str}_{has_or_no_causalmask_str}_"\
"{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
FMHA_INFER_INSTANCE_FNAME = (
"fmha_{mode}_infer_{dtype_str}_{has_or_no_causalmask_str}_"
"{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
)

FMHA_FORWARD_INSTANCE_TEMPLATE_INC = """
#include <ck_tile/core/numeric/{dtype_file}.hpp>
@@ -52,8 +54,10 @@
{max_k}>({cap_mode}ForwardParams& param, hipStream_t stream);
"""

FMHA_FORWARD_INSTANCE_FNAME = "fmha_{mode}_forward_{dtype_str}_{has_or_no_causalmask_str}_"\
"{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
FMHA_FORWARD_INSTANCE_FNAME = (
"fmha_{mode}_forward_{dtype_str}_{has_or_no_causalmask_str}_"
"{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
)

FMHA_BACKWARD_INSTANCE_TEMPLATE_INC = """
#include <ck_tile/core/numeric/{dtype_file}.hpp>
@@ -70,56 +74,55 @@
{max_k}>({cap_mode}BackwardParams& param, hipStream_t stream);
"""

FMHA_BACKWARD_INSTANCE_FNAME = "fmha_{mode}_backward_{dtype_str}_{has_or_no_causalmask_str}_"\
"{has_or_no_bias_str}_{has_or_no_biasgrad_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
FMHA_BACKWARD_INSTANCE_FNAME = (
"fmha_{mode}_backward_{dtype_str}_{has_or_no_causalmask_str}_"
"{has_or_no_bias_str}_{has_or_no_biasgrad_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
)

FMHA_INSTANCE_REF_FNAME = "instances/fmha_{mode}_{function}_{dtype}_instances_ref.h"

BOOL_MAP = {
True : "true",
False : "false"
}
BOOL_MAP = {True: "true", False: "false"}

BOOL_MAP_CAUSALMASK = {
True : "has_causalmask",
False : "no_causalmask",
True: "has_causalmask",
False: "no_causalmask",
}

BOOL_MAP_BIAS = {
True : "has_bias",
False : "no_bias",
True: "has_bias",
False: "no_bias",
}

BOOL_MAP_BIASGRAD = {
True : "has_biasgrad",
False : "no_biasgrad",
True: "has_biasgrad",
False: "no_biasgrad",
}

BOOL_MAP_DROPOUT = {
True : "has_dropout",
False : "no_dropout",
True: "has_dropout",
False: "no_dropout",
}

INT_MAP_MAX_K = {
32 : "maxk_32",
64 : "maxk_64",
128 : "maxk_128",
256 : "maxk_256",
32: "maxk_32",
64: "maxk_64",
128: "maxk_128",
256: "maxk_256",
}

TYPE_CTYPE_MAP = {
"fp16" : "ck_tile::fp16_t",
"bf16" : "ck_tile::bf16_t",
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
}

TYPE_FNAME_MAP = {
"fp16" : "half",
"bf16" : "bfloat16",
"fp16": "half",
"bf16": "bfloat16",
}

MODE_NAME_MAP = {
"batched" : "Batched",
"grouped" : "Grouped",
"batched": "Batched",
"grouped": "Grouped",
}


@@ -133,14 +136,18 @@ def create_infer_instances(instance_dir: Path, headdims: List) -> None:
fname = FMHA_INFER_INSTANCE_FNAME.format(
mode=mode,
dtype_str=dtype,
has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[has_causalmask],
has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[
has_causalmask
],
has_or_no_bias_str=BOOL_MAP_BIAS[has_bias],
has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout],
max_k_str=INT_MAP_MAX_K[max_k],
)
infer_instance_inc = FMHA_INFER_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
infer_instance_inc = (
FMHA_INFER_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
)
infer_instance = FMHA_INFER_INSTANCE_TEMPLATE.format(
extern="",
Expand All @@ -152,7 +159,11 @@ def create_infer_instances(instance_dir: Path, headdims: List) -> None:
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
)
(instance_dir / fname).write_text(FMHA_COPYRIGHT_HEADER + infer_instance_inc + infer_instance)
(instance_dir / fname).write_text(
FMHA_COPYRIGHT_HEADER
+ infer_instance_inc
+ infer_instance
)


def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None:
Expand All @@ -167,7 +178,7 @@ def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None:
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
with open(ref_fname, 'a') as file:
with open(ref_fname, "a") as file:
file.write(FMHA_COPYRIGHT_HEADER)
file.write(infer_instance_inc)
for max_k in headdims:
@@ -197,15 +208,19 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None:
fname = FMHA_FORWARD_INSTANCE_FNAME.format(
mode=mode,
dtype_str=dtype,
has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[has_causalmask],
has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[
has_causalmask
],
has_or_no_bias_str=BOOL_MAP_BIAS[has_bias],
has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout],
max_k_str=INT_MAP_MAX_K[max_k],
)
forward_instance_inc = FMHA_FORWARD_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
forward_instance_inc = (
FMHA_FORWARD_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
)
forward_instance = FMHA_FORWARD_INSTANCE_TEMPLATE.format(
extern="",
mode=mode,
@@ -216,7 +231,11 @@
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
)
(instance_dir / fname).write_text(FMHA_COPYRIGHT_HEADER + forward_instance_inc + forward_instance)
(instance_dir / fname).write_text(
FMHA_COPYRIGHT_HEADER
+ forward_instance_inc
+ forward_instance
)


def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
@@ -231,22 +250,24 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
with open(ref_fname, 'a') as file:
with open(ref_fname, "a") as file:
file.write(FMHA_COPYRIGHT_HEADER)
file.write(forward_instance_inc)
for max_k in headdims:
for has_bias in [True, False]:
for has_dropout in [True, False]:
for has_causalmask in [True, False]:
forward_instance = FMHA_FORWARD_INSTANCE_TEMPLATE.format(
extern="extern ",
mode=mode,
dtype=TYPE_CTYPE_MAP[dtype],
has_causalmask=BOOL_MAP[has_causalmask],
has_bias=BOOL_MAP[has_bias],
has_dropout=BOOL_MAP[has_dropout],
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
forward_instance = (
FMHA_FORWARD_INSTANCE_TEMPLATE.format(
extern="extern ",
mode=mode,
dtype=TYPE_CTYPE_MAP[dtype],
has_causalmask=BOOL_MAP[has_causalmask],
has_bias=BOOL_MAP[has_bias],
has_dropout=BOOL_MAP[has_dropout],
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
)
)
file.write(forward_instance)

@@ -255,21 +276,29 @@ def create_backward_instances(instance_dir: Path, headdims: List) -> None:
for mode in ["batched", "grouped"]:
for dtype in ["fp16", "bf16"]:
for has_causalmask in [True, False]:
for has_bias, has_bias_grad in [[True, False], [True, True], [False, False]]:
for has_bias, has_bias_grad in [
[True, False],
[True, True],
[False, False],
]:
for has_dropout in [True, False]:
for max_k in headdims:
fname = FMHA_BACKWARD_INSTANCE_FNAME.format(
mode=mode,
dtype_str=dtype,
has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[has_causalmask],
has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[
has_causalmask
],
has_or_no_bias_str=BOOL_MAP_BIAS[has_bias],
has_or_no_biasgrad_str=BOOL_MAP_BIASGRAD[has_bias_grad],
has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout],
max_k_str=INT_MAP_MAX_K[max_k],
)
backward_instance_inc = FMHA_BACKWARD_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
backward_instance_inc = (
FMHA_BACKWARD_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
)
backward_instance = FMHA_BACKWARD_INSTANCE_TEMPLATE.format(
extern="",
Expand All @@ -282,7 +311,11 @@ def create_backward_instances(instance_dir: Path, headdims: List) -> None:
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
)
(instance_dir / fname).write_text(FMHA_COPYRIGHT_HEADER + backward_instance_inc + backward_instance)
(instance_dir / fname).write_text(
FMHA_COPYRIGHT_HEADER
+ backward_instance_inc
+ backward_instance
)


def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None:
@@ -297,23 +330,29 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None:
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
with open(ref_fname, 'a') as file:
with open(ref_fname, "a") as file:
file.write(FMHA_COPYRIGHT_HEADER)
file.write(backward_instance_inc)
for max_k in headdims:
for has_bias, has_bias_grad in [[True, False], [True, True], [False, False]]:
for has_bias, has_bias_grad in [
[True, False],
[True, True],
[False, False],
]:
for has_dropout in [True, False]:
for has_causalmask in [True, False]:
backward_instance = FMHA_BACKWARD_INSTANCE_TEMPLATE.format(
extern="extern ",
mode=mode,
dtype=TYPE_CTYPE_MAP[dtype],
has_causalmask=BOOL_MAP[has_causalmask],
has_bias=BOOL_MAP[has_bias],
has_bias_grad=BOOL_MAP[has_bias_grad],
has_dropout=BOOL_MAP[has_dropout],
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
backward_instance = (
FMHA_BACKWARD_INSTANCE_TEMPLATE.format(
extern="extern ",
mode=mode,
dtype=TYPE_CTYPE_MAP[dtype],
has_causalmask=BOOL_MAP[has_causalmask],
has_bias=BOOL_MAP[has_bias],
has_bias_grad=BOOL_MAP[has_bias_grad],
has_dropout=BOOL_MAP[has_dropout],
max_k=max_k,
cap_mode=MODE_NAME_MAP[mode],
)
)
file.write(backward_instance)

@@ -334,7 +373,7 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None:
output_dir = Path(this_dir) / "instances"
output_dir.mkdir(parents=True, exist_ok=True)

## remove existing files in the directory
# remove existing files in the directory
files = os.listdir(output_dir)
for ff in files:
file_path = os.path.join(output_dir, ff)
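Most of the churn in generate_instances.py is mechanical reformatting (trailing commas, `key: value` spacing in dict literals, and parenthesized string literals replacing backslash continuations), likely produced by an auto-formatter such as black; the generated instances themselves are unchanged. A tiny standalone illustration of the parenthesized-string-template style adopted above (the shortened template and values are made up for the example):

```python
# Adjacent string literals inside parentheses are concatenated at compile time,
# which removes the need for backslash line continuations.
FNAME_TEMPLATE = (
    "fmha_{mode}_infer_{dtype_str}_"
    "{max_k_str}.cpp"
)

print(FNAME_TEMPLATE.format(mode="batched", dtype_str="fp16", max_k_str="maxk_128"))
# -> fmha_batched_infer_fp16_maxk_128.cpp
```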
4 changes: 2 additions & 2 deletions xformers/ops/fmha/ck.py
@@ -344,7 +344,7 @@ class BwOp(AttentionBwOpBase):
OPERATOR = get_operator("xformers", "efficient_attention_backward_ck")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
SUPPORTED_MAX_K = 256
SUPPORTED_MAX_K = 256
SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
type(None),
torch.Tensor,
@@ -369,7 +369,7 @@ class BwOp(AttentionBwOpBase):
32, # 64x64 kernel
64,
128, # 64x128/128x128 kernel
256,
256,
]

@classmethod
(Diff for the remaining changed file was not loaded on this page.)
