From f7b28c52a9b00aed07819266a4d54b899e92eb3f Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Aug 2024 19:45:57 +0000 Subject: [PATCH 1/4] apply black --- setup.py | 2 +- .../attention/hip_fmha/generate_instances.py | 175 +++++++++++------- xformers/ops/fmha/ck.py | 4 +- 3 files changed, 110 insertions(+), 71 deletions(-) diff --git a/setup.py b/setup.py index f648706e2b..abadb4a17f 100644 --- a/setup.py +++ b/setup.py @@ -451,7 +451,7 @@ def get_extensions(): "-Werror", "-Woverloaded-virtual", "-mllvm", - "-enable-post-misched=0" + "-enable-post-misched=0", ] + generator_flag + cc_flag, diff --git a/xformers/csrc/attention/hip_fmha/generate_instances.py b/xformers/csrc/attention/hip_fmha/generate_instances.py index fc27bcc545..bfbe5f345a 100644 --- a/xformers/csrc/attention/hip_fmha/generate_instances.py +++ b/xformers/csrc/attention/hip_fmha/generate_instances.py @@ -35,8 +35,10 @@ {max_k}>({cap_mode}ForwardParams& param, hipStream_t stream); """ -FMHA_INFER_INSTANCE_FNAME = "fmha_{mode}_infer_{dtype_str}_{has_or_no_causalmask_str}_"\ - "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" +FMHA_INFER_INSTANCE_FNAME = ( + "fmha_{mode}_infer_{dtype_str}_{has_or_no_causalmask_str}_" + "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" +) FMHA_FORWARD_INSTANCE_TEMPLATE_INC = """ #include @@ -52,8 +54,10 @@ {max_k}>({cap_mode}ForwardParams& param, hipStream_t stream); """ -FMHA_FORWARD_INSTANCE_FNAME = "fmha_{mode}_forward_{dtype_str}_{has_or_no_causalmask_str}_"\ - "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" +FMHA_FORWARD_INSTANCE_FNAME = ( + "fmha_{mode}_forward_{dtype_str}_{has_or_no_causalmask_str}_" + "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" +) FMHA_BACKWARD_INSTANCE_TEMPLATE_INC = """ #include @@ -70,56 +74,55 @@ {max_k}>({cap_mode}BackwardParams& param, hipStream_t stream); """ -FMHA_BACKWARD_INSTANCE_FNAME = "fmha_{mode}_backward_{dtype_str}_{has_or_no_causalmask_str}_"\ - "{has_or_no_bias_str}_{has_or_no_biasgrad_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" +FMHA_BACKWARD_INSTANCE_FNAME = ( + "fmha_{mode}_backward_{dtype_str}_{has_or_no_causalmask_str}_" + "{has_or_no_bias_str}_{has_or_no_biasgrad_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" +) FMHA_INSTANCE_REF_FNAME = "instances/fmha_{mode}_{function}_{dtype}_instances_ref.h" -BOOL_MAP = { - True : "true", - False : "false" -} +BOOL_MAP = {True: "true", False: "false"} BOOL_MAP_CAUSALMASK = { - True : "has_causalmask", - False : "no_causalmask", + True: "has_causalmask", + False: "no_causalmask", } BOOL_MAP_BIAS = { - True : "has_bias", - False : "no_bias", + True: "has_bias", + False: "no_bias", } BOOL_MAP_BIASGRAD = { - True : "has_biasgrad", - False : "no_biasgrad", + True: "has_biasgrad", + False: "no_biasgrad", } BOOL_MAP_DROPOUT = { - True : "has_dropout", - False : "no_dropout", + True: "has_dropout", + False: "no_dropout", } INT_MAP_MAX_K = { - 32 : "maxk_32", - 64 : "maxk_64", - 128 : "maxk_128", - 256 : "maxk_256", + 32: "maxk_32", + 64: "maxk_64", + 128: "maxk_128", + 256: "maxk_256", } TYPE_CTYPE_MAP = { - "fp16" : "ck_tile::fp16_t", - "bf16" : "ck_tile::bf16_t", + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", } TYPE_FNAME_MAP = { - "fp16" : "half", - "bf16" : "bfloat16", + "fp16": "half", + "bf16": "bfloat16", } MODE_NAME_MAP = { - "batched" : "Batched", - "grouped" : "Grouped", + "batched": "Batched", + "grouped": "Grouped", } @@ -133,14 +136,18 @@ def 
create_infer_instances(instance_dir: Path, headdims: List) -> None: fname = FMHA_INFER_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, - has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[has_causalmask], + has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[ + has_causalmask + ], has_or_no_bias_str=BOOL_MAP_BIAS[has_bias], has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], ) - infer_instance_inc = FMHA_INFER_INSTANCE_TEMPLATE_INC.format( - mode=mode, - dtype_file=TYPE_FNAME_MAP[dtype], + infer_instance_inc = ( + FMHA_INFER_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) ) infer_instance = FMHA_INFER_INSTANCE_TEMPLATE.format( extern="", @@ -152,7 +159,11 @@ def create_infer_instances(instance_dir: Path, headdims: List) -> None: max_k=max_k, cap_mode=MODE_NAME_MAP[mode], ) - (instance_dir / fname).write_text(FMHA_COPYRIGHT_HEADER + infer_instance_inc + infer_instance) + (instance_dir / fname).write_text( + FMHA_COPYRIGHT_HEADER + + infer_instance_inc + + infer_instance + ) def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None: @@ -167,7 +178,7 @@ def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None: mode=mode, dtype_file=TYPE_FNAME_MAP[dtype], ) - with open(ref_fname, 'a') as file: + with open(ref_fname, "a") as file: file.write(FMHA_COPYRIGHT_HEADER) file.write(infer_instance_inc) for max_k in headdims: @@ -197,15 +208,19 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None: fname = FMHA_FORWARD_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, - has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[has_causalmask], + has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[ + has_causalmask + ], has_or_no_bias_str=BOOL_MAP_BIAS[has_bias], has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], ) - forward_instance_inc = FMHA_FORWARD_INSTANCE_TEMPLATE_INC.format( - mode=mode, - dtype_file=TYPE_FNAME_MAP[dtype], - ) + forward_instance_inc = ( + FMHA_FORWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + ) forward_instance = FMHA_FORWARD_INSTANCE_TEMPLATE.format( extern="", mode=mode, @@ -216,7 +231,11 @@ def create_forward_instances(instance_dir: Path, headdims: List) -> None: max_k=max_k, cap_mode=MODE_NAME_MAP[mode], ) - (instance_dir / fname).write_text(FMHA_COPYRIGHT_HEADER + forward_instance_inc + forward_instance) + (instance_dir / fname).write_text( + FMHA_COPYRIGHT_HEADER + + forward_instance_inc + + forward_instance + ) def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None: @@ -231,22 +250,24 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None: mode=mode, dtype_file=TYPE_FNAME_MAP[dtype], ) - with open(ref_fname, 'a') as file: + with open(ref_fname, "a") as file: file.write(FMHA_COPYRIGHT_HEADER) file.write(forward_instance_inc) for max_k in headdims: for has_bias in [True, False]: for has_dropout in [True, False]: for has_causalmask in [True, False]: - forward_instance = FMHA_FORWARD_INSTANCE_TEMPLATE.format( - extern="extern ", - mode=mode, - dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], - has_bias=BOOL_MAP[has_bias], - has_dropout=BOOL_MAP[has_dropout], - max_k=max_k, - cap_mode=MODE_NAME_MAP[mode], + forward_instance = ( + FMHA_FORWARD_INSTANCE_TEMPLATE.format( + extern="extern ", + mode=mode, + dtype=TYPE_CTYPE_MAP[dtype], + has_causalmask=BOOL_MAP[has_causalmask], + has_bias=BOOL_MAP[has_bias], + has_dropout=BOOL_MAP[has_dropout], 
+ max_k=max_k, + cap_mode=MODE_NAME_MAP[mode], + ) ) file.write(forward_instance) @@ -255,21 +276,29 @@ def create_backward_instances(instance_dir: Path, headdims: List) -> None: for mode in ["batched", "grouped"]: for dtype in ["fp16", "bf16"]: for has_causalmask in [True, False]: - for has_bias, has_bias_grad in [[True, False], [True, True], [False, False]]: + for has_bias, has_bias_grad in [ + [True, False], + [True, True], + [False, False], + ]: for has_dropout in [True, False]: for max_k in headdims: fname = FMHA_BACKWARD_INSTANCE_FNAME.format( mode=mode, dtype_str=dtype, - has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[has_causalmask], + has_or_no_causalmask_str=BOOL_MAP_CAUSALMASK[ + has_causalmask + ], has_or_no_bias_str=BOOL_MAP_BIAS[has_bias], has_or_no_biasgrad_str=BOOL_MAP_BIASGRAD[has_bias_grad], has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], max_k_str=INT_MAP_MAX_K[max_k], ) - backward_instance_inc = FMHA_BACKWARD_INSTANCE_TEMPLATE_INC.format( - mode=mode, - dtype_file=TYPE_FNAME_MAP[dtype], + backward_instance_inc = ( + FMHA_BACKWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) ) backward_instance = FMHA_BACKWARD_INSTANCE_TEMPLATE.format( extern="", @@ -282,7 +311,11 @@ def create_backward_instances(instance_dir: Path, headdims: List) -> None: max_k=max_k, cap_mode=MODE_NAME_MAP[mode], ) - (instance_dir / fname).write_text(FMHA_COPYRIGHT_HEADER + backward_instance_inc + backward_instance) + (instance_dir / fname).write_text( + FMHA_COPYRIGHT_HEADER + + backward_instance_inc + + backward_instance + ) def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None: @@ -297,23 +330,29 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None: mode=mode, dtype_file=TYPE_FNAME_MAP[dtype], ) - with open(ref_fname, 'a') as file: + with open(ref_fname, "a") as file: file.write(FMHA_COPYRIGHT_HEADER) file.write(backward_instance_inc) for max_k in headdims: - for has_bias, has_bias_grad in [[True, False], [True, True], [False, False]]: + for has_bias, has_bias_grad in [ + [True, False], + [True, True], + [False, False], + ]: for has_dropout in [True, False]: for has_causalmask in [True, False]: - backward_instance = FMHA_BACKWARD_INSTANCE_TEMPLATE.format( - extern="extern ", - mode=mode, - dtype=TYPE_CTYPE_MAP[dtype], - has_causalmask=BOOL_MAP[has_causalmask], - has_bias=BOOL_MAP[has_bias], - has_bias_grad=BOOL_MAP[has_bias_grad], - has_dropout=BOOL_MAP[has_dropout], - max_k=max_k, - cap_mode=MODE_NAME_MAP[mode], + backward_instance = ( + FMHA_BACKWARD_INSTANCE_TEMPLATE.format( + extern="extern ", + mode=mode, + dtype=TYPE_CTYPE_MAP[dtype], + has_causalmask=BOOL_MAP[has_causalmask], + has_bias=BOOL_MAP[has_bias], + has_bias_grad=BOOL_MAP[has_bias_grad], + has_dropout=BOOL_MAP[has_dropout], + max_k=max_k, + cap_mode=MODE_NAME_MAP[mode], + ) ) file.write(backward_instance) diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py index 47ad90d2f9..889eeb4462 100644 --- a/xformers/ops/fmha/ck.py +++ b/xformers/ops/fmha/ck.py @@ -344,7 +344,7 @@ class BwOp(AttentionBwOpBase): OPERATOR = get_operator("xformers", "efficient_attention_backward_ck") SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES - SUPPORTED_MAX_K = 256 + SUPPORTED_MAX_K = 256 SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( type(None), torch.Tensor, @@ -369,7 +369,7 @@ class BwOp(AttentionBwOpBase): 32, # 64x64 kernel 64, 128, # 64x128/128x128 kernel - 256, + 256, ] @classmethod From 
fd82f20b6c7a3b2f30856d48575065e45cd10028 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Aug 2024 19:50:50 +0000 Subject: [PATCH 2/4] apply flake8 --- xformers/csrc/attention/hip_fmha/generate_instances.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xformers/csrc/attention/hip_fmha/generate_instances.py b/xformers/csrc/attention/hip_fmha/generate_instances.py index bfbe5f345a..d9a2763509 100644 --- a/xformers/csrc/attention/hip_fmha/generate_instances.py +++ b/xformers/csrc/attention/hip_fmha/generate_instances.py @@ -373,7 +373,7 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None: output_dir = Path(this_dir) / "instances" output_dir.mkdir(parents=True, exist_ok=True) - ## remove existing files in the directory + # remove existing files in the directory files = os.listdir(output_dir) for ff in files: file_path = os.path.join(output_dir, ff) From 7d21800f684e4d654cdec49e10ed545d03a598f9 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Aug 2024 20:43:02 +0000 Subject: [PATCH 3/4] fix mypy --- tests/test_mem_eff_attention.py | 6 +++--- xformers/attn_bias_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index ed6d6a696a..ad71241eda 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -37,13 +37,13 @@ if torch.cuda.is_available(): compute_capability = torch.cuda.get_device_capability("cuda") sm70_or_better_only = pytest.mark.skipif( - torch.version.cuda and compute_capability < (7, 0), reason="requires sm70+" + torch.version.cuda is not None and compute_capability < (7, 0), reason="requires sm70+" ) sm75_or_better_only = pytest.mark.skipif( - torch.version.cuda and compute_capability < (7, 5), reason="requires sm75+" + torch.version.cuda is not None and compute_capability < (7, 5), reason="requires sm75+" ) sm80_or_better_only = pytest.mark.skipif( - torch.version.cuda and compute_capability < (8, 0), reason="requires sm80+" + torch.version.cuda is not None and compute_capability < (8, 0), reason="requires sm80+" ) skip_if_rocm = pytest.mark.skipif( torch.version.hip is not None, reason="not supported on ROCm" diff --git a/xformers/attn_bias_utils.py b/xformers/attn_bias_utils.py index 224302c4f8..fb8d8207f2 100644 --- a/xformers/attn_bias_utils.py +++ b/xformers/attn_bias_utils.py @@ -39,7 +39,7 @@ def create_attn_bias( dtype, requires_grad: bool, fmt: str, - op: Type[AttentionOpBase], + op: Optional[Type[AttentionOpBase]] = None, page_size: Optional[int] = None, ): if bias_type is None or isinstance(None, bias_type): @@ -59,7 +59,7 @@ def create_attn_bias( * 3 ) attn_bias = attn_bias.expand(batch_size, num_heads, q_len, kv_len) - elif issubclass(op, fmha.triton_splitk.FwOp): + elif op is not None and issubclass(op, fmha.triton_splitk.FwOp): attn_bias = ( torch.randn( (batch_size, num_heads_groups, num_heads, q_len, kv_len), From d6b64568739952fd95bf4eb172d6fbbdd53964d1 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Aug 2024 21:05:42 +0000 Subject: [PATCH 4/4] revert disable flash operator on rocm --- xformers/ops/fmha/flash.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 14a8335ec1..49e708dc28 100644 --- a/xformers/ops/fmha/flash.py +++ 
b/xformers/ops/fmha/flash.py @@ -607,10 +607,7 @@ class FwOp(AttentionFwOpBase): implementation. """ - if torch.version.hip: - OPERATOR = None - else: - OPERATOR = get_operator("xformers_flash", "flash_fwd") + OPERATOR = get_operator("xformers_flash", "flash_fwd") SUPPORTED_DEVICES: Set[str] = {"cuda"} CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} @@ -812,10 +809,7 @@ def operator_flop( class BwOp(AttentionBwOpBase): __doc__ = FwOp.__doc__ - if torch.version.hip: - OPERATOR = None - else: - OPERATOR = get_operator("xformers_flash", "flash_bwd") + OPERATOR = get_operator("xformers_flash", "flash_bwd") SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
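
Note on PATCH 3/4 ("fix mypy"): torch.version.cuda is typed as Optional[str], so writing the skipif conditions as "torch.version.cuda is not None and ..." keeps them plain booleans, and giving create_attn_bias() an op parameter of Optional[Type[AttentionOpBase]] with a None default means issubclass() is only reached after an explicit None check. A minimal, self-contained sketch of that narrowing pattern follows; AttentionOpBase and TritonSplitKFwOp are simplified stand-ins here, not the real xformers classes.

    # Sketch only: stand-in classes, not the real xformers types.
    from typing import Optional, Type


    class AttentionOpBase:  # stand-in for the xformers operator base class
        pass


    class TritonSplitKFwOp(AttentionOpBase):  # stand-in for fmha.triton_splitk.FwOp
        pass


    def create_attn_bias(op: Optional[Type[AttentionOpBase]] = None) -> str:
        # Without the "op is not None" guard, mypy rejects issubclass(op, ...)
        # because op may be None; with the guard, op is narrowed to
        # Type[AttentionOpBase] inside the branch, and the call is also safe
        # at runtime when the argument is omitted.
        if op is not None and issubclass(op, TritonSplitKFwOp):
            return "split-k style bias"
        return "default bias"


    print(create_attn_bias())                  # -> "default bias"
    print(create_attn_bias(TritonSplitKFwOp))  # -> "split-k style bias"

The same narrowing idea applies to the test-file change in that patch: checking is not None first gives the skipif marker a plain bool condition instead of an Optional value.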