Updated test to include CK/AITER V2/V3 test in single backend case #454
| # FusedAttention backend | ||
| if fused_attn_supported: | ||
| if len(fused_attn_backends) == 1: | ||
| if len(fused_attn_backends) == 1 and FusedAttnBackend["CK"] not in fused_attn_backends: |
FusedAttnBackend["CK"] should be guarded: it is an invalid key on the NV platform.
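A minimal sketch of such a guard, using stand-in dictionaries rather than the real FusedAttnBackend mapping (the map contents, their integer values, and the helper name has_ck_backend_guarded are hypothetical, chosen only for illustration). It relies on Python's short-circuiting `and`, so the "CK" lookup is never evaluated when the platform flag is false:

```python
# Hypothetical stand-ins for the backend mapping on each platform.
# The real keys and values live in the library; these are illustrative only.
NV_BACKEND_MAP = {"No_Backend": -1}
ROCM_BACKEND_MAP = {"AOTriton": 1, "CK": 2}


def has_ck_backend_guarded(fused_attn_backends, is_hip_extension, backend_map):
    """Return True only on a HIP build where CK is among the backends.

    `and` short-circuits: when is_hip_extension is False, backend_map["CK"]
    is never evaluated, so a map without a "CK" key does not raise KeyError.
    """
    return bool(is_hip_extension and backend_map["CK"] in fused_attn_backends)


def has_ck_backend_unguarded(fused_attn_backends, backend_map):
    """Unguarded variant: raises KeyError if the map has no "CK" key."""
    return backend_map["CK"] in fused_attn_backends
```

With the platform check first, the same test code can run on both platforms without touching a key that only exists on one of them.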
| ) | ||
| if len(fused_attn_backends) == 2: | ||
| # We can consider the CK backend as being two, since we have V2/V3 kernels | ||
| elif len(fused_attn_backends) == 1: |
To avoid code duplication, this should rather sit outside the len(fused_attn_backends) checks, under a separate 'if IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends'.
That doesn't seem to be the case. The code is still duplicated for len(fused_attn_backends) == 1 and len(fused_attn_backends) == 2.
How many previously skipped test cases are run with that change? It only allows running configs that are not supported by AOTriton but supported by Unfused or FA.
Good point, I've updated the skip conditions to double-count CK directly.
| is_training, | ||
| ) | ||
| if len(fused_attn_backends) == 2: | ||
| # We can consider the CK backend as being two, since we have V2/V3 kernels |
This code duplicates the same addition under len(fused_attn_backends) == 2.
Running V2 is independent of len(fused_attn_backends); it is controlled by the presence of FusedAttnBackend["CK"], so it should be outside of that if.
I've reorganized the test a bit now, thanks.
| # Double-count the CK backend since we want to compare V2/V3 kernels | ||
| if ( | ||
| len(fused_attn_backends) + | ||
| int(IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends) + |
Why use int? Because FusedAttnBackend["CK"] provides two variants, its presence alone is enough to continue, so add 'not (IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends)' to the original condition.
The int expression double-counts CK as a backend, so if fused_attn_backends contains only FusedAttnBackend["CK"], then int(...) == 1 and CK is counted twice. Applying this to the condition seemed to explain more literally why the skip is avoided (i.e. by adjusting the backend count directly, rather than by adding an excluding condition).
I'm happy to adjust to your approach though.
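The two formulations discussed above can be compared side by side. This is only a sketch: the skip threshold of 2 ("fewer than two backends to compare") and the function names are assumptions, since the full condition is not visible in the diff excerpt, and n_backends stands for the total number of available backends with CK counted once.

```python
def skip_by_count(n_backends, has_ck):
    """Count-based form: CK contributes one extra 'virtual' backend.

    Assumes the test is skipped when fewer than two backends are available.
    """
    return n_backends + int(has_ck) < 2


def skip_by_exclusion(n_backends, has_ck):
    """Exclusion-based form: never skip when CK is present."""
    return n_backends < 2 and not has_ck


def forms_agree(max_backends=4):
    """Check both forms over every consistent (n_backends, has_ck) pair.

    has_ck=True implies CK is itself one of the backends, so n_backends >= 1;
    the inconsistent pair (0, True) is excluded.
    """
    for n in range(max_backends + 1):
        for has_ck in (False, True):
            if has_ck and n == 0:
                continue
            if skip_by_count(n, has_ck) != skip_by_exclusion(n, has_ck):
                return False
    return True
```

Under these assumptions the two conditions select the same cases, so the choice between them is about readability rather than behavior.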
| torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols) | ||
| for i, _ in enumerate(fused_attn_bwd): | ||
| torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols) | ||
| if ( |
Just "if has_ck_backend"; then you can remove the V2/V3 comparison from the len(fused_attn_backends) == 2 section.
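A rough sketch of the suggested structure, with the V2/V3 comparison hoisted out of the length checks so it appears only once. The function planned_comparisons and its tuple labels are hypothetical; the real test runs the kernels and calls torch.testing.assert_close rather than building a list.

```python
def planned_comparisons(fused_attn_backends, has_ck_backend):
    """Return the comparisons the test would perform, in order.

    The backend-vs-backend comparison depends on how many fused backends
    exist; the CK V2-vs-V3 comparison depends only on CK being present,
    so it sits outside the length check instead of being duplicated.
    """
    comparisons = []
    if len(fused_attn_backends) == 2:
        comparisons.append(
            ("backend", fused_attn_backends[0], fused_attn_backends[1])
        )
    if has_ck_backend:
        comparisons.append(("ck_kernel", "V2", "V3"))
    return comparisons
```

For example, planned_comparisons(["CK"], True) yields only the V2/V3 comparison, while a two-backend list that includes CK yields both.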
| RoPE, | ||
| is_training, | ||
| ) | ||
| if has_ck_backend: |
Apply the same logic as in test_dot_product_attention.
Description
Modifies the PyTorch FA tests to treat the CK backend as two virtual backends, so that the V2 and V3 implementations are compared directly even when no other backend is available.
Note that V3 will sometimes fall back to V2, so this is not a rigorous test and does not substitute for having another backend to compare against, but it enables more tests that would otherwise be skipped.
Fixes https://github.com/ROCm/frameworks-internal/issues/15114