Skip to content

Commit 02d05b4

Browse files
authored
Make use of the SYCL queue passed to CUTLASS (#3)
Use the SYCL queue passed to the function generated for SYCL CUTLASS. It is used for two purposes: (1) to get the compute unit count of the device actually used, and (2) to pass to the GemmUniversalAdapter `initialize` and `run` calls for synchronization. Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent e28cd77 commit 02d05b4

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

torch/_inductor/codegen/xpu/gemm_template.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@
4242
using coord_t = cutlass::gemm::GemmCoord::Index;
4343
static cutlass::KernelHardwareInfo hw_info;
4444
45-
// TODO (SYCL) : device_id here is only used for hw info and doesn't necessarily mean
46-
// it's linked to the SYCL queue. It's hardcoded to 0 in the CUDA version as well.
47-
const int device_id = 0;
45+
const int device_id = syclcompat::get_device_id(stream->get_device());
4846
4947
if (hw_info.sm_count == 0) {
5048
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(device_id);
@@ -75,15 +73,11 @@
7573
#endif
7674
#endif
7775
{
78-
// TODO (SYCL): Pass the SYCL queue (currently last arg of `kernel_call_signature` above)
79-
// once supported on CUTLASS side. Variable name to respect the naming in: _EXTRA_CPP_ARGS (sycl_kernel.py)
80-
auto status = gemm_op.initialize(arguments, workspace);
76+
auto status = gemm_op.initialize(arguments, workspace, stream);
8177
CUTLASS_CHECK(status);
8278
}
8379
{
84-
// TODO (SYCL): Pass the SYCL queue once supported on CUTLASS side.
85-
// Variable name to respect the naming in: _EXTRA_CPP_ARGS (sycl_kernel.py)
86-
auto status = gemm_op.run();
80+
auto status = gemm_op.run(stream);
8781
CUTLASS_CHECK(status);
8882
syclcompat::wait_and_throw();
8983
}

torch/_inductor/codegen/xpu/sycl_kernel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,7 @@ class SYCLTemplateKernel(SYCLKernel):
154154
Template kernels defined by SYCL / Cutlass in C++.
155155
"""
156156

157-
# TODO (SYCL): The SYCL queue is not being used
158-
_EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, sycl::queue stream"
157+
_EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, sycl::queue* stream"
159158

160159
def __init__(
161160
self,

0 commit comments

Comments
 (0)