Skip to content

Commit d76676d

Browse files
committed
Enable bf16 accumulation
Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent ff5972c commit d76676d

File tree

3 files changed

+3
-10
lines changed

3 files changed

+3
-10
lines changed

torch/_inductor/codegen/xpu/cutlass_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def get_accumulator_dtype(
203203
return None
204204

205205
if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes):
206-
return torch.float
206+
return torch.bfloat16
207207
else:
208208
raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes}")
209209

torch/_inductor/codegen/xpu/gemm_template.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -652,11 +652,7 @@ def __init__(
652652
beta: float,
653653
input_reorder: Optional[list[int]] = None,
654654
):
655-
# TODO (SYCL) : This is a workaround hardcoding output type (layout) to float32
656-
# Should be removed once not limited to the bfloat input->float32 accum cutlass configurations
657-
float_layout = copy.deepcopy(layout)
658-
float_layout.dtype = float32
659-
super().__init__(input_nodes, float_layout, alpha, beta, input_reorder)
655+
super().__init__(input_nodes, layout, alpha, beta, input_reorder)
660656

661657
@staticmethod
662658
def add_cutlass_gemm_choices(

torch/_inductor/scheduler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2715,10 +2715,7 @@ def replace_operation_buffer(
27152715
out_buffer = out_storage.data
27162716
assert isinstance(out_buffer, ir.OperationBuffer)
27172717

2718-
# TODO (SYCL): This is a temporary hack to allow auto-tuning
2719-
# while our CUTLASS does not support bf16 accumulation for
2720-
# GEMM. Uncomment this line when it is supported.
2721-
#out_buffer.layout = multi_node.layout
2718+
out_buffer.layout = multi_node.layout
27222719
replace_operation_buffer(multi_node, out_buffer)
27232720
new_scheduler_node = self.create_scheduler_node(out_buffer)
27242721

0 commit comments

Comments
 (0)