Commit e21c49d

Enable autotuning for SYCL CUTLASS

Enable autotuning for SYCL CUTLASS by completing the SYCL benchmark request class. Also adds a temporary workaround to allow bf16 GEMM to accumulate in FP32 in code paths used when auto-tuning is active.

Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>

1 parent e28cd77 · commit e21c49d
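
For context, these templates are exercised through Inductor's max-autotune path. A minimal sketch of a workload that would now route candidate SYCL CUTLASS kernels through real benchmarking; the device string, shapes, and compile mode here are illustrative assumptions, not part of this commit:

import torch

# Hypothetical repro: a bf16 GEMM compiled with max-autotune on an XPU
# device; with this commit, candidate kernels are timed via
# SYCLBenchmarkRequest instead of a stub. Shapes and flags are assumptions.
a = torch.randn(1024, 1024, device="xpu", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="xpu", dtype=torch.bfloat16)

@torch.compile(mode="max-autotune")
def gemm(x, y):
    return x @ y

out = gemm(a, b)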

File tree

3 files changed (+7, −5):

torch/_inductor/autotune_process.py
torch/_inductor/codegen/xpu/sycl_kernel.py
torch/_inductor/scheduler.py

torch/_inductor/autotune_process.py

Lines changed: 1 addition & 2 deletions

@@ -889,8 +889,7 @@ def get_tuning_process_pool() -> TuningProcessPool:
 
 class SYCLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
     # Important: Instances of this class have to be serializable
-    # across process boundaries. Do not put Tensors in here!
-    # TODO (SYCL) : Complete the bmrq class to enable full autotuning
+    # across process boundaries. Do not put device tensors in here!
     def __init__(
         self,
         kernel_name: str,
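
The retained comment matters because benchmark requests are pickled and shipped to tuning subprocesses (see get_tuning_process_pool above). A generic sketch of that constraint, with all names hypothetical rather than Inductor's actual classes:

import pickle
from dataclasses import dataclass
from typing import List, Tuple

# Generic illustration of the serializability constraint: the request
# carries only picklable metadata (name, shapes, dtype), never live
# device tensors, so it can cross the process boundary to a worker.
@dataclass
class ToyBenchmarkRequest:
    kernel_name: str
    input_shapes: List[Tuple[int, int]]
    dtype: str

    def benchmark(self) -> float:
        # The worker would materialize tensors from the metadata above
        # and time the kernel; elided in this sketch.
        raise NotImplementedError

req = ToyBenchmarkRequest("gemm_bf16", [(1024, 1024), (1024, 1024)], "bfloat16")
payload = pickle.dumps(req)  # succeeds because no device state is held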

torch/_inductor/codegen/xpu/sycl_kernel.py

Lines changed: 2 additions & 2 deletions

@@ -468,8 +468,8 @@ def precompile(self) -> None:
         self.bmreq.precompile()
 
     def benchmark(self, *args, out) -> float:
-        # TODO (SYCL) : Enable benchmarking once supported
-        return 0.001
+        assert self.bmreq is not None
+        return self.bmreq.benchmark(*args, output_tensor=out)
 
     def __str__(self) -> str:
         return f"SYCLTemplateCaller(source_file={self.bmreq.source_file})"
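
Before this change the stub reported a constant 0.001 for every candidate, so the tuner could not rank kernels; delegating to self.bmreq.benchmark produces real timings. A toy illustration with hypothetical candidate names and numbers:

# Hypothetical timings: a constant stub makes the argmin arbitrary,
# while real measurements let the tuner pick an actual winner.
stub_times = {"cfg_a": 0.001, "cfg_b": 0.001, "cfg_c": 0.001}
real_times = {"cfg_a": 0.42, "cfg_b": 0.31, "cfg_c": 0.55}

print(min(stub_times, key=stub_times.get))  # tie: arbitrary choice
print(min(real_times, key=real_times.get))  # cfg_b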

torch/_inductor/scheduler.py

Lines changed: 4 additions & 1 deletion

@@ -2715,7 +2715,10 @@ def replace_operation_buffer(
             out_buffer = out_storage.data
             assert isinstance(out_buffer, ir.OperationBuffer)
 
-            out_buffer.layout = multi_node.layout
+            # TODO (SYCL): This is a temporary hack to allow auto-tuning
+            # while our CUTLASS does not support bf16 accumulation for
+            # GEMM. Uncomment this line when it is supported.
+            #out_buffer.layout = multi_node.layout
             replace_operation_buffer(multi_node, out_buffer)
             new_scheduler_node = self.create_scheduler_node(out_buffer)
 
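For context on why the layout override is disabled: a bf16 GEMM whose CUTLASS kernel accumulates and writes FP32 produces a buffer whose dtype no longer matches the bf16 layout recorded on the multi-template node. A rough, standalone illustration of that dtype mismatch; shapes and dtypes are assumptions, not Inductor code:

import torch

# The kernel effectively computes in FP32 (left), while the template
# node's layout expects bf16 (right); forcing the bf16 layout onto the
# FP32 buffer would misdescribe its contents.
a = torch.randn(64, 64, dtype=torch.bfloat16)
b = torch.randn(64, 64, dtype=torch.bfloat16)

acc_fp32 = a.float() @ b.float()   # FP32-accumulated result
print(acc_fp32.dtype)              # torch.float32
print(torch.bfloat16)              # dtype the recorded layout expects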