@@ -21,12 +21,18 @@
 from ..utils import flat_product, multi_gpu_test
 
 
+class Matches(NamedTuple):
+    attention_fusion: int
+    allreduce_fusion: int = 0
+    sequence_parallel: int = 0
+    async_tp: int = 0
+
+
 class ModelBackendTestCase(NamedTuple):
     model_name: str
     model_kwargs: dict[str, Any]
     backend: _Backend
-    attention_fusions: int
-    allreduce_fusions: int | None = None
+    matches: Matches
 
 
 MODELS_FP8: list[ModelBackendTestCase] = []
@@ -40,15 +46,23 @@ class ModelBackendTestCase(NamedTuple):
         model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.TRITON_ATTN,
-        attention_fusions=32,
-        allreduce_fusions=65,
+        matches=Matches(
+            attention_fusion=32,
+            allreduce_fusion=65,
+            sequence_parallel=65,
+            async_tp=128,
+        ),
     ),
     ModelBackendTestCase(
         model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
         model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
         backend=_Backend.FLASHINFER,
-        attention_fusions=48,
-        allreduce_fusions=96,
+        matches=Matches(
+            attention_fusion=48,
+            allreduce_fusion=96,
+            sequence_parallel=96,
+            async_tp=190,
+        ),
     ),
 ]
 
@@ -57,8 +71,12 @@ class ModelBackendTestCase(NamedTuple):
         model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
         model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
         backend=_Backend.FLASHINFER,
-        attention_fusions=48,
-        allreduce_fusions=96,
+        matches=Matches(
+            attention_fusion=48,
+            allreduce_fusion=96,
+            sequence_parallel=96,
+            async_tp=190,
+        ),
     ),
 ]
 
@@ -68,8 +86,12 @@ class ModelBackendTestCase(NamedTuple):
         model_name="meta-llama/Llama-3.1-8B-Instruct",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.TRITON_ATTN,
-        attention_fusions=0,
-        allreduce_fusions=65,
+        matches=Matches(
+            attention_fusion=32,
+            allreduce_fusion=65,
+            sequence_parallel=65,
+            async_tp=128,
+        ),
     ),
 ]
 
@@ -79,19 +101,19 @@ class ModelBackendTestCase(NamedTuple):
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.TRITON_ATTN,
-        attention_fusions=32,
+        matches=Matches(attention_fusion=32),
     ),
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.ROCM_ATTN,
-        attention_fusions=32,
+        matches=Matches(attention_fusion=32),
     ),
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
-        attention_fusions=32,
+        matches=Matches(attention_fusion=32),
     ),
 ]
 
@@ -100,8 +122,7 @@ class ModelBackendTestCase(NamedTuple):
 
 
 @pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, "
-    "attention_fusions, allreduce_fusions, custom_ops",
+    "model_name, model_kwargs, backend, matches, custom_ops",
     # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
     list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
     # quant_fp4 only has the custom impl
@@ -112,8 +133,7 @@ def test_attn_quant(
     model_name: str,
     model_kwargs: dict[str, Any],
     backend: _Backend,
-    attention_fusions: int,
-    allreduce_fusions: int,
+    matches: Matches,
     custom_ops: str,
     inductor_graph_partition: bool,
     caplog_mp_spawn,
@@ -160,12 +180,12 @@ def test_attn_quant(
     with caplog_mp_spawn(logging.DEBUG) as log_holder:
         run_model(compilation_config, model_name, **model_kwargs)
 
-    matches = re.findall(
+    log_matches = re.findall(
         r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
         log_holder.text,
     )
-    assert len(matches) == 1, log_holder.text
-    assert int(matches[0]) == attention_fusions
+    assert len(log_matches) == 1, log_holder.text
+    assert int(log_matches[0]) == matches.attention_fusion
 
 
 # TODO(luka) test both in nightly
@@ -179,8 +199,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, "
-    "attention_fusions, allreduce_fusions, custom_ops",
+    "model_name, model_kwargs, backend, matches, custom_ops",
     # Toggle RMSNorm and QuantFP8 for FP8 models
     list(
         flat_product(
@@ -201,8 +220,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     model_name: str,
     model_kwargs: dict,
     backend: _Backend,
-    attention_fusions: int,
-    allreduce_fusions: int,
+    matches: Matches,
     custom_ops: str,
     inductor_graph_partition: bool,
     caplog_mp_spawn,
@@ -250,23 +268,23 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
         run_model(
             compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
         )
-    matches = re.findall(
+    log_matches = re.findall(
         r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
         log_holder.text,
     )
-    assert len(matches) == 2, log_holder.text
+    assert len(log_matches) == 2, log_holder.text
 
-    assert int(matches[0]) == attention_fusions
-    assert int(matches[1]) == attention_fusions
+    assert int(log_matches[0]) == matches.attention_fusion
+    assert int(log_matches[1]) == matches.attention_fusion
 
-    matches = re.findall(
+    log_matches = re.findall(
         r"collective_fusion.py:\d+] Replaced (\d+) patterns",
         log_holder.text,
     )
-    assert len(matches) == 2, log_holder.text
+    assert len(log_matches) == 2, log_holder.text
 
-    assert int(matches[0]) == allreduce_fusions
-    assert int(matches[1]) == allreduce_fusions
+    assert int(log_matches[0]) == matches.allreduce_fusion
+    assert int(log_matches[1]) == matches.allreduce_fusion
 
 
 def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
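Note: both tests above repeat the same findall-then-assert pattern against the spawned worker log. A minimal sketch of a helper that could fold that repetition is shown below; the helper name, signature, and usage are hypothetical and not part of this diff, and the regexes are the ones the tests already use.

import re


def assert_fusion_count(log_text: str, pattern: str, expected: int, times: int = 1) -> None:
    # Find every occurrence of the fusion log line and check that each
    # reported count equals the expected value from the Matches tuple.
    log_matches = re.findall(pattern, log_text)
    assert len(log_matches) == times, log_text
    for count in log_matches:
        assert int(count) == expected


# Hypothetical usage mirroring test_tp2_attn_quant_allreduce_rmsnorm (TP=2,
# so the line is expected once per rank):
# assert_fusion_count(
#     log_holder.text,
#     r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
#     matches.attention_fusion,
#     times=2,
# )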