Summary
Recently, Triton added the `scaled_dot` op, which consumes A and B in f8/f6/f4 formats packed into int32, with e8m0 scales passed via the int8 datatype: https://github.com/triton-lang/triton/pull/4795/files#diff-1d96a0ed473569188c00d6e16c54dd7050e0a66040438ac630c889aef7cbbbe8R1544
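For concreteness, here is a rough sketch (not torchao code) of how the primitive might be called from a Triton kernel for an mxfp8 matmul. The `tl.dot_scaled` name, its argument order, and the "e4m3" format strings are assumptions taken from the linked PR and should be checked against the installed Triton; masking is omitted, so M, N, K are assumed to be multiples of the block sizes and BLOCK_K a multiple of 32.

```python
import triton
import triton.language as tl


@triton.jit
def mx_fp8_matmul_kernel(
    a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, c_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Each program computes one [BLOCK_M, BLOCK_N] tile of C = A @ B.
    # A is [M, K] and B is [K, N], both fp8 stored as uint8; the scale tensors hold
    # one e8m0 (uint8) value per 32 elements along K for each row of A / column of B.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        offs_sk = k0 // 32 + tl.arange(0, BLOCK_K // 32)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (K // 32) + offs_sk[None, :])
        b_scale = tl.load(b_scale_ptr + offs_n[:, None] * (K // 32) + offs_sk[None, :])
        # Assumed signature from the linked PR:
        # dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc)
        acc = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3", acc)
    # c is assumed to be a float32 output buffer
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```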
Steps
- Implement a new mx matmul in Triton, and add utilities to ensure that this op is only available when a new enough Triton is installed (a gating sketch follows this list)
- Write unit tests verifying the correctness of the implementation against the existing upcast-and-matmul approach
- Update MXTensor's dispatch to (based on config) use the new op instead of upcasting and running the matmul in the original precision (a config-gated sketch also follows this list). The current fallback is in ao/torchao/prototype/mx_formats/mx_ops.py, lines 64 to 68 at 48bc81c:

  ```python
  b = args[1]
  assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
  a_hp = a.to_dtype(a._orig_dtype)
  b_hp = b.to_dtype(b._orig_dtype)
  res = aten_op(a_hp, b_hp)
  ```

- Create profile + memory traces
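For the Triton-version gating in the first step, one possible shape (the helper name is hypothetical, not torchao's actual API) is to feature-detect the op rather than compare version strings:

```python
def has_triton_scaled_dot() -> bool:
    """Return True if the installed Triton exposes the scaled dot primitive."""
    try:
        import triton.language as tl
    except ImportError:
        return False
    # The op landed as `dot_scaled` in triton.language in the linked PR; adjust the
    # attribute name here if it differs in the installed Triton version.
    return hasattr(tl, "dot_scaled")
```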
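And a minimal sketch of the config-gated dispatch change from the third step, reusing the gating helper above; the flag and wrapper names (`use_triton_scaled_dot`, `mx_scaled_dot`) are placeholders, not existing torchao APIs:

```python
def mx_mm(aten_op, a, b, config):
    # Route to the new Triton kernel when the config opts in and Triton supports it.
    if getattr(config, "use_triton_scaled_dot", False) and has_triton_scaled_dot():
        # Hypothetical wrapper that feeds the fp8/fp6/fp4 payload and e8m0 scales
        # directly to the new kernel, without upcasting.
        return mx_scaled_dot(a, b)
    # Existing fallback: upcast both operands and run the matmul in the original precision.
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    return aten_op(a_hp, b_hp)
```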