Commit 0ec4969
uses current CUDAStream correctly (#118)
* Fix CUDA stream synchronization in custom kernels
This commit fixes GitHub issue pytorch/pytorch#157363 where custom CUDA
kernels were not properly synchronized with PyTorch's CUDA stream when
used with torch.compile in reduce-overhead mode.
Changes:
- Add #include <ATen/cuda/CUDAContext.h> for getCurrentCUDAStream()
- Use at::cuda::getCurrentCUDAStream() to get PyTorch's current CUDA stream
- Launch all kernels with the correct stream parameter
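The pattern those three bullets describe can be sketched as follows. This is a minimal illustration, not code from this repo: `my_kernel`, `my_op`, and their arguments are hypothetical stand-ins for the actual custom kernels.

```cuda
// Sketch of the fix; my_kernel / my_op are hypothetical stand-ins.
#include <ATen/cuda/CUDAContext.h>  // provides at::cuda::getCurrentCUDAStream()
#include <torch/extension.h>

__global__ void my_kernel(const float* in, float* out, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 2.0f;
}

torch::Tensor my_op(torch::Tensor input) {
  auto output = torch::empty_like(input);
  int64_t n = input.numel();
  int threads = 256;
  int blocks = (int)((n + threads - 1) / threads);

  // Before the fix: my_kernel<<<blocks, threads>>>(...) implicitly used the
  // default stream. After the fix: launch on PyTorch's current stream, so the
  // kernel is ordered after preceding PyTorch ops (e.g. nn.Linear) on that
  // stream. CUDAStream converts implicitly to cudaStream_t.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  my_kernel<<<blocks, threads, 0, stream>>>(
      input.data_ptr<float>(), output.data_ptr<float>(), n);
  return output;
}
```

Passing the stream as the fourth launch-configuration argument is what ties the custom kernel into PyTorch's stream ordering; no explicit `cudaDeviceSynchronize()` is needed.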
The issue occurred because custom kernels were launched on the default CUDA stream
while PyTorch operations (such as nn.Linear) ran on PyTorch's managed stream.
This created a race condition in which a custom kernel could execute before the
preceding PyTorch operations had completed, producing incorrect output values.
With this fix, all custom kernels are properly synchronized with PyTorch's
CUDA stream, ensuring correct execution order and preventing race conditions
when used with torch.compile.
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
* Add tests for torch.compile stream synchronization fix
Added comprehensive tests to verify the fix for GitHub issue pytorch/pytorch#157363:
1. test_compile_with_linear_layer:
- Tests custom CUDA kernels with nn.Linear + torch.compile
- Verifies correct behavior with various input sizes (1000, 5000, 10000)
- Uses reduce-overhead mode to reproduce the original issue conditions
2. test_compile_custom_only:
- Tests custom operations without linear layers
- Ensures custom operations work correctly with torch.compile
These tests ensure that custom CUDA kernels properly synchronize with
PyTorch's CUDA stream when used with torch.compile, preventing race
conditions that previously caused incorrect outputs.
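A minimal sketch of what such a test can look like (requires a CUDA GPU; the extension name `myext`, the op `my_op`, and the layer sizes are assumptions based on the description above, not copied from the repo):

```python
# Hypothetical sketch of the regression test; torch.ops.myext.my_op stands
# in for the repo's custom CUDA operator.
import unittest
import torch
# PyTorch's TestCase overrides assertEqual to compare tensors with
# appropriate tolerances.
from torch.testing._internal.common_utils import TestCase

class TestCompileStreamSync(TestCase):
    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_compile_with_linear_layer(self):
        linear = torch.nn.Linear(1000, 1000).cuda()

        def fn(x):
            # nn.Linear runs on PyTorch's stream; the custom op must be
            # ordered after it to see the finished result.
            return torch.ops.myext.my_op(linear(x))

        # reduce-overhead mode reproduces the original issue conditions.
        compiled = torch.compile(fn, mode="reduce-overhead")
        for n in (1000, 5000, 10000):
            x = torch.randn(n, 1000, device="cuda")
            # Eager execution is the reference; before the fix the compiled
            # path could race and return stale values.
            self.assertEqual(compiled(x), fn(x))
```

Iterating over several input sizes matters here: stream races are timing-dependent, and larger matmuls widen the window in which an unsynchronized kernel can read incomplete results.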
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
* Use self.assertEqual instead of torch.testing.assert_close in tests
Replace manual tolerance specification with self.assertEqual, which
automatically selects appropriate tolerances for tensor comparisons.
This makes the tests more concise and follows PyTorch testing conventions.
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
---------
Co-authored-by: Claude <noreply@anthropic.com>

1 parent e9a1c5d · commit 0ec4969
2 files changed: +57 −3 lines changed