Commit bc53ae6

Enable autotuning and bf16 accumulation for SYCL CUTLASS (#4)
Enable autotuning for SYCL CUTLASS by completing the SYCL benchmark request class. Also remove a temporary workaround that forced float32 accumulation, so that GEMM can now accumulate in bfloat16. This addresses one of the items left open in #2.

---------

Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent 02d05b4 commit bc53ae6
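
For context, here is a minimal sketch of how the path enabled by this commit could be exercised from user code. It assumes this branch routes max-autotune GEMM lowering for "xpu" tensors through the SYCL CUTLASS backend; the config options shown are standard Inductor options, but whether they cover SYCL CUTLASS here is an assumption, not something stated in the commit.

# Sketch only: exercise CUTLASS GEMM autotuning on an XPU device with bf16 inputs.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CUTLASS"  # assumed to select SYCL CUTLASS on this branch

a = torch.randn(1024, 1024, device="xpu", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="xpu", dtype=torch.bfloat16)

compiled_mm = torch.compile(torch.mm, mode="max-autotune")
# With this commit, candidate CUTLASS kernels are actually benchmarked during
# autotuning (instead of returning a placeholder time) and may accumulate in bfloat16.
out = compiled_mm(a, b)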

File tree

5 files changed (+22 −11 lines)
third_party/cutlass (submodule pointer update; diff not shown)

torch/_inductor/autotune_process.py

Lines changed: 1 addition & 2 deletions
@@ -889,8 +889,7 @@ def get_tuning_process_pool() -> TuningProcessPool:
 
 class SYCLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
     # Important: Instances of this class have to be serializable
-    # across process boundaries. Do not put Tensors in here!
-    # TODO (SYCL) : Complete the bmrq class to enable full autotuning
+    # across process boundaries. Do not put device tensors in here!
     def __init__(
         self,
         kernel_name: str,

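The "serializable across process boundaries" comment above matters because Inductor can benchmark autotuning candidates in a dedicated tuning subprocess (the TuningProcessPool referenced in the hunk header). A rough sketch of that constraint; autotune_in_subprocess is an existing Inductor config option, but how this branch combines it with the SYCL backend is an assumption:

# Sketch only: the benchmark request is pickled and sent to a tuning subprocess,
# so it should carry tensor metadata (shape, dtype, device) rather than live device tensors.
import torch._inductor.config as inductor_config

inductor_config.autotune_in_subprocess = True  # benchmark candidates out-of-process
# Each SYCLBenchmarkRequest is then serialized across the process boundary and the
# subprocess recreates the input/output tensors from the recorded metadata before timing.
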
torch/_inductor/codegen/xpu/cutlass_utils.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def get_accumulator_dtype(
         return None
 
     if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes):
-        return torch.float
+        return torch.bfloat16
     else:
         raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes}")
 

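A small usage sketch of the change above. It assumes the helper's signature mirrors the CUDA counterpart and takes just the list of input dtypes:

# Sketch only: illustrates the new accumulator dtype selection.
import torch
from torch._inductor.codegen.xpu import cutlass_utils

acc = cutlass_utils.get_accumulator_dtype([torch.bfloat16, torch.bfloat16])
assert acc == torch.bfloat16  # before this commit, this returned torch.float
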
torch/_inductor/codegen/xpu/gemm_template.py

Lines changed: 17 additions & 5 deletions
@@ -652,11 +652,7 @@ def __init__(
         beta: float,
         input_reorder: Optional[list[int]] = None,
     ):
-        # TODO (SYCL) : This is a workaround hardcoding output type (layout) to float32
-        # Should be removed once not limited to the bfloat input->float32 accum cutlass configurations
-        float_layout = copy.deepcopy(layout)
-        float_layout.dtype = float32
-        super().__init__(input_nodes, float_layout, alpha, beta, input_reorder)
+        super().__init__(input_nodes, layout, alpha, beta, input_reorder)
 
     @staticmethod
     def add_cutlass_gemm_choices(

@@ -780,14 +776,30 @@ def _set_bias_layout_and_alignment(
         self,
         op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
     ) -> bool:
+        import cutlass_library.library as cutlass_lib
+
         has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
         if has_bias:
             bias = self.input_nodes[2]
+            # Bias data type
+            op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(
+                bias.get_layout().dtype
+            )
+            assert op.C.element == op.D.element, (
+                f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}"
+            )
+
+            # Bias layout
             bias_layout = CUTLASSGemmTemplate.cutlass_layout(bias.get_layout())
             op.C.layout = bias_layout
+
+            # Bias alignment
             status = self.set_alignment(bias.get_layout(), op.C)
             if not status:
                 return False
+
+        else:
+            op.C.element = cutlass_lib.DataType.void
         return True
 
     def _dtype_match(

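The second hunk above mirrors the bias dtype and layout onto the CUTLASS C operand and, when no bias input is present, marks C as void so the epilogue has no source tensor to read. A condensed sketch of that decision, written as a free function purely for illustration (configure_c_operand is not an API in this codebase):

# Sketch only: condensed form of the bias handling added above.
import cutlass_library.library as cutlass_lib

def configure_c_operand(op, bias_dtype, to_cutlass_type):
    # Mirror the bias dtype onto op.C, or disable the C source entirely.
    if bias_dtype is not None:
        op.C.element = to_cutlass_type(bias_dtype)
        # The template additionally copies the bias layout and alignment onto op.C.
    else:
        op.C.element = cutlass_lib.DataType.void
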
torch/_inductor/codegen/xpu/sycl_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -467,8 +467,8 @@ def precompile(self) -> None:
         self.bmreq.precompile()
 
     def benchmark(self, *args, out) -> float:
-        # TODO (SYCL) : Enable benchmarking once supported
-        return 0.001
+        assert self.bmreq is not None
+        return self.bmreq.benchmark(*args, output_tensor=out)
 
     def __str__(self) -> str:
         return f"SYCLTemplateCaller(source_file={self.bmreq.source_file})"

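With the hard-coded 0.001 placeholder gone, the caller defers to the benchmark request, which times the candidate kernel on the device. A rough conceptual sketch of what such a timing loop amounts to; the actual SYCLBenchmarkRequest implementation may use a different mechanism (for example device events):

# Sketch only: conceptual equivalent of benchmarking one candidate kernel on XPU.
import time
import torch

def time_kernel_ms(run_kernel, *args, warmup=3, iters=10):
    for _ in range(warmup):
        run_kernel(*args)
    torch.xpu.synchronize()  # make sure warm-up launches have finished
    start = time.perf_counter()
    for _ in range(iters):
        run_kernel(*args)
    torch.xpu.synchronize()  # wait for all timed launches to complete
    return (time.perf_counter() - start) * 1e3 / iters  # mean time per run, in ms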