Skip to content

Commit 865b0bf

Browse files
Simplify operation implementation: remove mode switching, always use global counter
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
1 parent 2c81fbb commit 865b0bf

File tree

2 files changed

+16
-39
lines changed

2 files changed

+16
-39
lines changed

tests/compile/piecewise/test_simple.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,12 +17,9 @@
1717

1818
# Import shared test operations
1919
from tests.compile.test_operations import (
20-
get_global_counter, reset_global_counter, enable_counting_mode
20+
get_global_counter, reset_global_counter
2121
)
2222

23-
# Enable counting mode for this test's specific behavior
24-
enable_counting_mode()
25-
2623

2724
@support_torch_compile
2825
class SillyModel(nn.Module):
@@ -36,9 +33,8 @@ def __init__(self,
3633

3734
def forward(self, x: torch.Tensor) -> torch.Tensor:
3835
"""
39-
Overall effect:
40-
x += 1
41-
x[0] += 2
36+
Overall effect with unified attention implementation:
37+
input [0., 0.] -> final output [19., 19.]
4238
global_counter += 2
4339
"""
4440
x = x + 1
@@ -107,4 +103,4 @@ def test_simple_piecewise_compile(use_inductor):
107103
batch_descriptor=BatchDescriptor(num_tokens=2, )):
108104
output = model(input)
109105
assert get_global_counter() == 2
110-
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
106+
assert torch.allclose(output.cpu(), torch.tensor([19., 19.]))

tests/compile/test_operations.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
running tests together.
1010
1111
The main "attention" operation is automatically registered when this module
12-
is imported. Individual test files can access additional functionality
12+
is imported. Individual test files can access the global counter functionality
1313
through helper functions.
1414
"""
1515

@@ -23,53 +23,34 @@
2323
silly_lib = Library("silly", "FRAGMENT")
2424

2525

26-
# Global state for test_simple.py compatibility
26+
# Global counter that all tests can use or ignore
2727
_global_counter = 0
28-
_use_counting_mode = False
2928

3029

3130
def get_global_counter():
32-
"""Get the current global counter value (for test_simple.py)"""
31+
"""Get the current global counter value"""
3332
return _global_counter
3433

3534

3635
def reset_global_counter():
37-
"""Reset the global counter to 0 (for test_simple.py)"""
36+
"""Reset the global counter to 0"""
3837
global _global_counter
3938
_global_counter = 0
4039

4140

42-
def enable_counting_mode():
43-
"""Enable counting mode for test_simple.py"""
44-
global _use_counting_mode
45-
_use_counting_mode = True
46-
reset_global_counter()
47-
48-
49-
def disable_counting_mode():
50-
"""Disable counting mode"""
51-
global _use_counting_mode
52-
_use_counting_mode = False
53-
54-
5541
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
5642
out: torch.Tensor) -> None:
5743
"""
58-
Unified attention implementation that can handle both standard and counting modes.
44+
Unified attention implementation that depends on all inputs and affects the output.
45+
Always increments a global counter that tests can use or ignore.
5946
"""
60-
global _global_counter, _use_counting_mode
47+
global _global_counter
48+
49+
# Always increment the global counter
50+
_global_counter += 1
6151

62-
if _use_counting_mode:
63-
# Counting mode for test_simple.py
64-
_global_counter += 1
65-
print(f"global_counter={_global_counter}")
66-
out.copy_(q)
67-
out[0] += 1
68-
else:
69-
# Standard mode for test_multiple_graphs.py and test_toy_llama.py
70-
out.copy_(q)
71-
out += k
72-
out += v
52+
# Unified implementation that depends on all inputs
53+
out.copy_(q + k + v)
7354

7455

7556
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,

0 commit comments

Comments (0)