```diff
 running tests together.
 
 The main "attention" operation is automatically registered when this module
-is imported. Individual test files can access additional functionality
+is imported. Individual test files can access the global counter functionality
 through helper functions.
 """
 
...
 silly_lib = Library("silly", "FRAGMENT")
 
 
-# Global state for test_simple.py compatibility
+# Global counter that all tests can use or ignore
 _global_counter = 0
-_use_counting_mode = False
 
 
 def get_global_counter():
-    """Get the current global counter value (for test_simple.py)"""
+    """Get the current global counter value"""
     return _global_counter
 
 
 def reset_global_counter():
-    """Reset the global counter to 0 (for test_simple.py)"""
+    """Reset the global counter to 0"""
     global _global_counter
     _global_counter = 0
 
 
-def enable_counting_mode():
-    """Enable counting mode for test_simple.py"""
-    global _use_counting_mode
-    _use_counting_mode = True
-    reset_global_counter()
-
-
-def disable_counting_mode():
-    """Disable counting mode"""
-    global _use_counting_mode
-    _use_counting_mode = False
-
-
 def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     out: torch.Tensor) -> None:
     """
-    Unified attention implementation that can handle both standard and counting modes.
+    Unified attention implementation that depends on all inputs and affects the output.
+    Always increments a global counter that tests can use or ignore.
     """
-    global _global_counter, _use_counting_mode
+    global _global_counter
+
+    # Always increment the global counter
+    _global_counter += 1
 
-    if _use_counting_mode:
-        # Counting mode for test_simple.py
-        _global_counter += 1
-        print(f"global_counter={_global_counter}")
-        out.copy_(q)
-        out[0] += 1
-    else:
-        # Standard mode for test_multiple_graphs.py and test_toy_llama.py
-        out.copy_(q)
-        out += k
-        out += v
+    # Unified implementation that depends on all inputs
+    out.copy_(q + k + v)
 
 
 def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
```
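The hunk ends before the registration code, but the module docstring states that the "attention" op is registered automatically on import via `silly_lib`. As a rough sketch only — the schema and dispatch keys below are assumptions, not shown in this diff — the wiring with `torch.library` could look like:

```python
import torch
from torch.library import Library

silly_lib = Library("silly", "FRAGMENT")  # as defined in the module above

# Assumed schema: `out` is mutated in place (Tensor(a!)) and nothing is returned.
silly_lib.define(
    "attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")

# Real kernel for execution, fake kernel for tracing (both from the diff above).
silly_lib.impl("attention", silly_attention, "CompositeExplicitAutograd")
silly_lib.impl("attention", silly_attention_fake, "Meta")
```

With wiring along those lines, the op is callable as `torch.ops.silly.attention(q, k, v, out)` and survives `torch.compile` graph capture as an opaque custom op.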
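A hypothetical usage sketch, assuming the registration above and that the module has been imported: because the unified kernel computes `q + k + v`, the output depends on every input, so the compiler cannot silently eliminate any of them, and the counter increments exactly once per real invocation.

```python
import torch

reset_global_counter()

q, k, v = torch.randn(4), torch.randn(4), torch.randn(4)
out = torch.empty_like(q)

# Assumes the op is registered under the "silly" namespace as sketched above.
torch.ops.silly.attention(q, k, v, out)

assert torch.allclose(out, q + k + v)   # unified semantics: depends on all inputs
assert get_global_counter() == 1        # one increment per call
```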
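The diff cuts off at `silly_attention_fake`'s signature, so its body is not shown. For an op that only writes into a preallocated `out`, the fake (meta) implementation used during tracing typically has nothing to compute; a plausible no-op body, not necessarily the file's actual code:

```python
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    # Tracing-time stand-in: shapes and dtypes come from the preallocated
    # `out`, so there is nothing to propagate or mutate here.
    return
```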