Summary
Recently, Triton added the scaled_dot op,
which consumes A and B in f8/f6/f4 packed into int32 format, with e8m0 scales passed via the int8 datatype. https://github.com/triton-lang/triton/pull/4795/files#diff-1d96a0ed473569188c00d6e16c54dd7050e0a66040438ac630c889aef7cbbbe8R1544
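
A minimal kernel sketch of how the op might be invoked, assuming the frontend exposes it as `tl.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc)` and one e8m0 scale byte per 32 contiguous K elements; the exact name, signature, and scale layout may differ across Triton versions:

```python
import triton
import triton.language as tl


@triton.jit
def mx_matmul_kernel(
    a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, c_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Boundary masks omitted for brevity; assumes M, N, K are multiples
    # of the block sizes, and that a_ptr/b_ptr hold packed fp8 payloads.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        # Row-major A (M, K) and B (K, N) tiles of packed elements.
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        # One e8m0 scale byte per 32 K-elements (assumed layout).
        offs_ks = k // 32 + tl.arange(0, BLOCK_K // 32)
        a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (K // 32) + offs_ks[None, :])
        b_scale = tl.load(b_scale_ptr + offs_n[:, None] * (K // 32) + offs_ks[None, :])
        acc = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3", acc)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```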
Steps
- Implement the new MX matmul in Triton, and add utilities to ensure that this op is only available when a new enough Triton version is installed (see the gating sketch after this list)
- Write unit tests verifying the correctness of the implementation against the existing upcast-and-matmul approach (see the test sketch below)
- Update MXTensor's dispatch to (based on config) use the new op instead of upcasting and running the matmul in the original precision (see the dispatch sketch below): torchao/prototype/mx_formats/mx_ops.py, lines 64 to 68 at commit 48bc81c
- Create profile and memory traces comparing the two paths (see the profiling sketch below)
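
For the availability gate, a sketch using feature detection rather than version-string parsing; `has_triton_scaled_dot` is a hypothetical helper name, not an existing torchao API:

```python
import torch


def has_triton_scaled_dot() -> bool:
    """Best-effort check that the installed Triton exposes a scaled-dot op."""
    if not torch.cuda.is_available():
        return False
    try:
        import triton.language as tl
    except ImportError:
        return False
    # Feature-detect the attribute; the frontend name has varied
    # (dot_scaled vs. scaled_dot) across Triton revisions.
    return hasattr(tl, "dot_scaled") or hasattr(tl, "scaled_dot")
```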
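For the unit tests, a minimal sketch comparing against the existing upcast-and-matmul path; `mx_scaled_matmul` is a hypothetical entry point, and the `MXTensor.to_mx` / `to_dtype` calls assume the current prototype API:

```python
import pytest
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor


@pytest.mark.skipif(not has_triton_scaled_dot(), reason="requires a new enough Triton")
def test_mx_matmul_matches_upcast_reference():
    M, K, N = 128, 256, 64
    a = torch.randn(M, K, device="cuda")
    b = torch.randn(K, N, device="cuda")
    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, block_size=32)
    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, block_size=32)
    # Reference path: upcast both operands and matmul in original precision.
    ref = a_mx.to_dtype(torch.float32) @ b_mx.to_dtype(torch.float32)
    # New path: hypothetical wrapper around the Triton scaled-dot kernel.
    out = mx_scaled_matmul(a_mx, b_mx)
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```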
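The dispatch change in mx_ops.py could look roughly like the following, assuming the existing `implements` decorator and upcast path in that file; `config.use_triton_scaled_dot` and `mx_scaled_matmul` are hypothetical names:

```python
import torch

aten = torch.ops.aten


@implements([aten.mm.default])
def mx_mm(aten_op, args, kwargs=None):
    a, b = args[0], args[1]
    if config.use_triton_scaled_dot and has_triton_scaled_dot():
        # New path: hand the packed fp8/fp6/fp4 payloads and e8m0 scales
        # directly to the Triton scaled-dot kernel, skipping the upcast.
        return mx_scaled_matmul(a, b)
    # Existing path: upcast and run the matmul in the original precision.
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    return aten_op(a_hp, b_hp)
```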
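For the traces, standard PyTorch tooling should suffice; a sketch reusing the hypothetical entry point from above:

```python
import torch
from torch.profiler import ProfilerActivity, profile

# _record_memory_history / _dump_snapshot are private but documented APIs.
torch.cuda.memory._record_memory_history()
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = mx_scaled_matmul(a_mx, b_mx)
    torch.cuda.synchronize()
prof.export_chrome_trace("mx_matmul_trace.json")
torch.cuda.memory._dump_snapshot("mx_matmul_memory.pickle")
```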