Commit af31964

refactor e2e test to use Matches object
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
1 parent 2b5cfc8 commit af31964

File tree: 1 file changed (+50, −32 lines)

tests/compile/test_fusions_e2e.py

Lines changed: 50 additions & 32 deletions
@@ -21,12 +21,18 @@
 from ..utils import flat_product, multi_gpu_test
 
 
+class Matches(NamedTuple):
+    attention_fusion: int
+    allreduce_fusion: int = 0
+    sequence_parallel: int = 0
+    async_tp: int = 0
+
+
 class ModelBackendTestCase(NamedTuple):
     model_name: str
     model_kwargs: dict[str, Any]
     backend: _Backend
-    attention_fusions: int
-    allreduce_fusions: int | None = None
+    matches: Matches
 
 
 MODELS_FP8: list[ModelBackendTestCase] = []
@@ -40,15 +46,23 @@ class ModelBackendTestCase(NamedTuple):
         model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.TRITON_ATTN,
-        attention_fusions=32,
-        allreduce_fusions=65,
+        matches=Matches(
+            attention_fusion=32,
+            allreduce_fusion=65,
+            sequence_parallel=65,
+            async_tp=128,
+        ),
     ),
     ModelBackendTestCase(
         model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
         model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
         backend=_Backend.FLASHINFER,
-        attention_fusions=48,
-        allreduce_fusions=96,
+        matches=Matches(
+            attention_fusion=48,
+            allreduce_fusion=96,
+            sequence_parallel=96,
+            async_tp=190,
+        ),
     ),
 ]
 
@@ -57,8 +71,12 @@ class ModelBackendTestCase(NamedTuple):
         model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
         model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
         backend=_Backend.FLASHINFER,
-        attention_fusions=48,
-        allreduce_fusions=96,
+        matches=Matches(
+            attention_fusion=48,
+            allreduce_fusion=96,
+            sequence_parallel=96,
+            async_tp=190,
+        ),
     ),
 ]
 
@@ -68,8 +86,12 @@ class ModelBackendTestCase(NamedTuple):
         model_name="meta-llama/Llama-3.1-8B-Instruct",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.TRITON_ATTN,
-        attention_fusions=0,
-        allreduce_fusions=65,
+        matches=Matches(
+            attention_fusion=32,
+            allreduce_fusion=65,
+            sequence_parallel=65,
+            async_tp=128,
+        ),
     ),
 ]
 
@@ -79,19 +101,19 @@ class ModelBackendTestCase(NamedTuple):
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.TRITON_ATTN,
-        attention_fusions=32,
+        matches=Matches(attention_fusion=32),
     ),
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.ROCM_ATTN,
-        attention_fusions=32,
+        matches=Matches(attention_fusion=32),
     ),
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
         backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
-        attention_fusions=32,
+        matches=Matches(attention_fusion=32),
     ),
 ]
 
@@ -100,8 +122,7 @@ class ModelBackendTestCase(NamedTuple):
 
 
 @pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, "
-    "attention_fusions, allreduce_fusions, custom_ops",
+    "model_name, model_kwargs, backend, matches, custom_ops",
     # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
     list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
     # quant_fp4 only has the custom impl
@@ -112,8 +133,7 @@ def test_attn_quant(
     model_name: str,
     model_kwargs: dict[str, Any],
     backend: _Backend,
-    attention_fusions: int,
-    allreduce_fusions: int,
+    matches: Matches,
     custom_ops: str,
     inductor_graph_partition: bool,
     caplog_mp_spawn,
@@ -160,12 +180,12 @@ def test_attn_quant(
     with caplog_mp_spawn(logging.DEBUG) as log_holder:
         run_model(compilation_config, model_name, **model_kwargs)
 
-    matches = re.findall(
+    log_matches = re.findall(
         r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
         log_holder.text,
     )
-    assert len(matches) == 1, log_holder.text
-    assert int(matches[0]) == attention_fusions
+    assert len(log_matches) == 1, log_holder.text
+    assert int(log_matches[0]) == matches.attention_fusion
 
 
 # TODO(luka) test both in nightly
@@ -179,8 +199,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
-    "model_name, model_kwargs, backend, "
-    "attention_fusions, allreduce_fusions, custom_ops",
+    "model_name, model_kwargs, backend, matches, custom_ops",
     # Toggle RMSNorm and QuantFP8 for FP8 models
     list(
         flat_product(
@@ -201,8 +220,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     model_name: str,
    model_kwargs: dict,
     backend: _Backend,
-    attention_fusions: int,
-    allreduce_fusions: int,
+    matches: Matches,
     custom_ops: str,
     inductor_graph_partition: bool,
     caplog_mp_spawn,
@@ -250,23 +268,23 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     run_model(
         compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
     )
-    matches = re.findall(
+    log_matches = re.findall(
         r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
         log_holder.text,
     )
-    assert len(matches) == 2, log_holder.text
+    assert len(log_matches) == 2, log_holder.text
 
-    assert int(matches[0]) == attention_fusions
-    assert int(matches[1]) == attention_fusions
+    assert int(log_matches[0]) == matches.attention_fusion
+    assert int(log_matches[1]) == matches.attention_fusion
 
-    matches = re.findall(
+    log_matches = re.findall(
         r"collective_fusion.py:\d+] Replaced (\d+) patterns",
         log_holder.text,
     )
-    assert len(matches) == 2, log_holder.text
+    assert len(log_matches) == 2, log_holder.text
 
-    assert int(matches[0]) == allreduce_fusions
-    assert int(matches[1]) == allreduce_fusions
+    assert int(log_matches[0]) == matches.allreduce_fusion
+    assert int(log_matches[1]) == matches.allreduce_fusion
 
 
 def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
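In short, the refactor collapses the separate attention_fusions / allreduce_fusions parameters into a single Matches NamedTuple whose unused fields default to 0, so single-GPU cases can write Matches(attention_fusion=32) while TP cases fill in all four counts, and assertions read as matches.attention_fusion instead of a bare positional int. A minimal, self-contained sketch of the pattern follows; the check_attention_fusion helper and the sample log line are illustrative only, not code from the repository:

import re
from typing import NamedTuple


class Matches(NamedTuple):
    # Expected pattern counts per compilation pass; passes a test case
    # does not exercise simply keep their 0 default.
    attention_fusion: int
    allreduce_fusion: int = 0
    sequence_parallel: int = 0
    async_tp: int = 0


def check_attention_fusion(log_text: str, matches: Matches) -> None:
    # Hypothetical helper (not part of the test file): pull the fused-node
    # count out of the pass's log line and compare it to the expected value.
    log_matches = re.findall(r"Fused quant onto (\d+) attention nodes", log_text)
    assert len(log_matches) == 1, log_text
    assert int(log_matches[0]) == matches.attention_fusion


# A case that only cares about attention fusion leaves the other counts at 0.
check_attention_fusion(
    "fusion_attn.py:42] Fused quant onto 32 attention nodes",
    Matches(attention_fusion=32),
)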
