Labels: enhancement (New feature or request), good first issue (Good for newcomers), mx
Summary
Triton recently added `scaled_dot`, which consumes A and B in fp8 / fp6 / fp4 (packed into int32) together with e8m0 scales supplied via the int8 datatype: https://github.com/triton-lang/triton/pull/4795/files#diff-1d96a0ed473569188c00d6e16c54dd7050e0a66040438ac630c889aef7cbbbe8R1544
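For concreteness, a minimal sketch of the operand layout this is built on, assuming fp8 (e4m3) elements and one e8m0 scale (carried as a uint8) per group of 32 elements along K; the shapes and dtypes below are illustrative, not the exact contract of the Triton op:

```python
# Illustrative only: the mx operand layout the fused kernel would consume.
# Assumes e4m3 elements and a 32-element scaling block along K.
import torch

M, K, N, BLOCK_SIZE = 128, 256, 64, 32

# fp8 (e4m3) operands, one byte per element; fp4 would pack two elements per byte.
a_elems = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b_elems = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)

# e8m0 scales carried as uint8: one scale per 32-element group along K.
# A biased exponent of 127 corresponds to a scale of 1.0.
a_scales = torch.full((M, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda")
b_scales = torch.full((N, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda")
```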
Steps
- Implement the new mx matmul in Triton, and add utilities to ensure this op is only available when a new enough Triton is installed (see the availability-guard sketch after this list)
- Write unit tests verifying the correctness of the implementation against the existing upcast-and-matmul approach (see the test sketch after this list)
- Update MXTensor's dispatch to use the new op (based on config) instead of upcasting and running in the original precision (see the dispatch sketch after this list). The current path, from ao/torchao/prototype/mx_formats/mx_ops.py lines 64 to 68 at 48bc81c:

  ```python
  b = args[1]
  assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
  a_hp = a.to_dtype(a._orig_dtype)
  b_hp = b.to_dtype(b._orig_dtype)
  res = aten_op(a_hp, b_hp)
  ```
- Create profile and memory traces (see the profiling sketch after this list)
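For the availability guard in the first step, a minimal sketch that probes for the new op by attribute rather than hard-coding a version; the exact frontend name (`dot_scaled` vs `scaled_dot`) should be confirmed against the linked Triton PR, and the helper name here is just a placeholder:

```python
# Hypothetical helper: gate the fused mx matmul on the installed Triton
# actually exposing a scaled-dot entry point.
import importlib.util


def has_triton_scaled_dot() -> bool:
    if importlib.util.find_spec("triton") is None:
        return False
    import triton.language as tl

    # The frontend name may differ across Triton versions; check both spellings.
    return hasattr(tl, "dot_scaled") or hasattr(tl, "scaled_dot")
```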
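A hedged sketch of the kind of unit test the second step asks for, comparing a hypothetical fused-kernel wrapper (`mx_scaled_matmul`, not an existing torchao function) against the current upcast-then-matmul reference; the import paths, `MXTensor.to_mx` signature, and block size of 32 are assumptions to be checked against the torchao sources:

```python
# test_mx_scaled_matmul.py (sketch): fused scaled-dot kernel vs. upcast reference.
import pytest
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor  # assumed import path
from torchao.prototype.mx_formats.mx_scaled_mm import mx_scaled_matmul  # hypothetical


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_mx_scaled_matmul_matches_upcast_reference():
    M, K, N = 128, 256, 64
    a_hp = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    b_hp = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")

    # Quantize to mx; to_mx(data_hp, elem_dtype, block_size) is the assumed signature.
    a_mx = MXTensor.to_mx(a_hp, torch.float8_e4m3fn, 32)
    b_mx = MXTensor.to_mx(b_hp, torch.float8_e4m3fn, 32)

    # Reference: the existing dequantize-then-matmul path.
    ref = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16)

    # Candidate: the new fused scaled-dot kernel (hypothetical API).
    out = mx_scaled_matmul(a_mx, b_mx)

    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```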
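A sketch of what the config-gated dispatch in mx_ops.py might look like; `config.use_scaled_dot_kernel` and `mx_scaled_matmul` are placeholder names for whatever the implementation introduces, `has_triton_scaled_dot` is the guard sketched above, and the fallback branch is the existing code quoted in the list:

```python
# Sketch of the dispatch change inside the aten.mm handler in
# torchao/prototype/mx_formats/mx_ops.py; flag and kernel names are placeholders.
a = args[0]
b = args[1]
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)

if config.use_scaled_dot_kernel and has_triton_scaled_dot():
    # Fused path: hand the fp8/fp6/fp4 payloads and e8m0 scales directly to the
    # Triton scaled-dot kernel, skipping the high-precision upcast.
    res = mx_scaled_matmul(a, b)
else:
    # Existing path: dequantize both operands and run the matmul in the
    # original high precision.
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    res = aten_op(a_hp, b_hp)
```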
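For the last step, a small sketch of collecting a kernel trace and a CUDA memory snapshot with standard PyTorch tooling; `run_mx_matmul` is a placeholder for whichever mx matmul benchmark the implementation ends up using:

```python
# Collect a Chrome trace and a memory snapshot for the new vs. old matmul paths.
import torch
from torch.profiler import ProfilerActivity, profile


def run_mx_matmul():
    # Placeholder workload; swap in the real mx matmul benchmark.
    a = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    return a @ b


# Kernel-level profile -> viewable in chrome://tracing or Perfetto.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_mx_matmul()
    torch.cuda.synchronize()
prof.export_chrome_trace("mx_matmul_trace.json")

# Memory trace -> viewable at pytorch.org/memory_viz.
torch.cuda.memory._record_memory_history(max_entries=100_000)
run_mx_matmul()
torch.cuda.synchronize()
torch.cuda.memory._dump_snapshot("mx_matmul_memory.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```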