Skip to content

Commit dd3c836

Browse files
authored
bugfix: Fix compile error of OptionalCUDAGuard and device_of (#613)
There are some compile errors in the main branch, for example:

```
/app/python/csrc_aot/single_prefill.cu(59): error: namespace "at::cuda" has no member "OptionalCUDAGuard"
    const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
                    ^
/app/python/csrc_aot/single_prefill.cu(59): error: identifier "device_of" is undefined
    const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
                                                   ^
```

This change adds the missing `#include <c10/cuda/CUDAGuard.h>` and replaces `at::cuda::OptionalCUDAGuard device_guard(device_of(device))` with `at::cuda::CUDAGuard device_guard(device)`, since `device` is already a concrete `c10::Device` (not an optional tensor), so the non-optional guard is the correct construct. cc @yzh119
1 parent b53a46f commit dd3c836

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

python/csrc_aot/batch_decode.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
#include <c10/cuda/CUDAGuard.h>
1617
#include <torch/extension.h>
1718

1819
#include <flashinfer/attention/decode_params.cuh>
@@ -42,7 +43,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
4243
size_t int_workspace_size_in_bytes =
4344
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
4445
auto device = float_workspace_buffer.device();
45-
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
46+
const at::cuda::CUDAGuard device_guard(device);
4647
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
4748
TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");
4849

@@ -113,7 +114,7 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun(
113114
}
114115
uint32_t head_dim = q.size(2);
115116

116-
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
117+
const at::cuda::CUDAGuard device_guard(device);
117118
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
118119
torch::Tensor o = torch::empty_like(q);
119120
if (maybe_lse) {

python/csrc_aot/batch_prefill.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
#include <c10/cuda/CUDAGuard.h>
1617
#include <torch/extension.h>
1718

1819
#include <flashinfer/attention/mask.cuh>
@@ -50,7 +51,7 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
5051
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
5152

5253
auto device = float_workspace_buffer.device();
53-
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
54+
const at::cuda::CUDAGuard device_guard(device);
5455
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
5556
TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
5657
TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");

python/csrc_aot/single_decode.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
#include <c10/cuda/CUDAGuard.h>
17+
1618
#include <flashinfer/attention/decode_params.cuh>
1719
#include <flashinfer/attention/variants.cuh>
1820
#include <optional>
@@ -60,7 +62,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
6062
kv_len = k.size(1);
6163
}
6264
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
63-
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
65+
const at::cuda::CUDAGuard device_guard(device);
6466
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
6567
auto o = torch::empty_like(q);
6668

python/csrc_aot/single_prefill.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
#include <c10/cuda/CUDAGuard.h>
1617
#include <torch/extension.h>
1718

1819
#include <flashinfer/attention/mask.cuh>
@@ -56,7 +57,7 @@ torch::Tensor single_prefill_with_kv_cache(
5657
kv_stride_h = k.stride(0);
5758
kv_stride_n = k.stride(1);
5859
}
59-
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
60+
const at::cuda::CUDAGuard device_guard(device);
6061
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
6162
auto o = torch::empty_like(q, q.options());
6263
if (maybe_lse) {

0 commit comments

Comments (0)