
[MX | Triton] Create MX matmul op using new scaled_dot op in Triton #1084

Description

@drisspg

Summary

Triton recently added scaled_dot, which consumes A and B in f8/f6/f4 packed into int32 format, plus e8m0 scales via the int8 datatype. https://github.com/triton-lang/triton/pull/4795/files#diff-1d96a0ed473569188c00d6e16c54dd7050e0a66040438ac630c889aef7cbbbe8R1544
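
A minimal single-tile sketch of what a kernel built on this could look like. Assumptions not confirmed by this issue: the op is exposed as tl.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format) (names and signature may differ across Triton versions), fp8 elements travel as raw bytes, scales are laid out as (M, K // 32) for A and (N, K // 32) for B, and K == BLOCK_K so there is only one K tile.

    # Hypothetical MX matmul kernel built on Triton's scaled-dot primitive.
    # Assumes K == BLOCK_K (one K tile) and row-major inputs for brevity.
    import triton
    import triton.language as tl

    @triton.jit
    def mx_matmul_kernel(a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, c_ptr,
                         M, N, K,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        # low-precision elements travel as raw bytes; the format strings
        # below tell dot_scaled how to interpret them
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        # one e8m0 scale (a biased exponent stored in 8 bits) per
        # 32-element group along K
        offs_sk = tl.arange(0, BLOCK_K // 32)
        a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (K // 32) + offs_sk[None, :])
        b_scale = tl.load(b_scale_ptr + offs_n[:, None] * (K // 32) + offs_sk[None, :])
        # accumulate in fp32; "e4m3" marks both operands as fp8e4m3
        acc = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3")
        tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.bfloat16))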

Steps

  1. Implement the new MX matmul op in Triton, and add utilities to ensure the op is only available when a new enough version of Triton is installed (see the version-gate sketch after this list)
  2. Write unit tests verifying the correctness of the implementation against the existing upcast-and-matmul approach (see the test sketch after this list)
  3. Update MXTensor's dispatch (based on config) to use the new op instead of upcasting and running in the original precision (see the dispatch sketch after this list); the current upcast path is:
    a = args[0]
    b = args[1]
    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
    # current behavior: upcast both operands and run the matmul
    # in the original high precision
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    res = aten_op(a_hp, b_hp)
  4. Create profile + memory traces (see the profiler sketch after this list)
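
For step 1, a minimal sketch of the availability gate. Rather than parsing version strings, it probes for the symbol directly; both dot_scaled and scaled_dot are checked since the frontend name may differ across Triton versions.

    # Hypothetical helper gating the MX matmul op on Triton support.
    def has_triton_scaled_dot() -> bool:
        try:
            import triton.language as tl
        except ImportError:
            return False
        # probe for the symbol instead of comparing version strings
        return hasattr(tl, "dot_scaled") or hasattr(tl, "scaled_dot")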
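For step 2, a sketch of one such test, comparing the fused path against the existing upcast-and-matmul reference. mx_matmul is a placeholder for the op from step 1, and the MXTensor.to_mx signature is an assumption about the current prototype API.

    # Sketch: the fused scaled_dot path and the upcast reference consume the
    # same quantized data, so they should agree up to accumulation noise.
    import pytest
    import torch

    @pytest.mark.skipif(not has_triton_scaled_dot(), reason="requires Triton scaled_dot")
    def test_mx_matmul_matches_upcast_reference():
        torch.manual_seed(0)
        a_hp = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
        b_hp = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
        a = MXTensor.to_mx(a_hp, torch.float8_e4m3fn, 32)
        b = MXTensor.to_mx(b_hp, torch.float8_e4m3fn, 32)
        # reference: today's behavior, upcast and matmul in original precision
        ref = a.to_dtype(torch.bfloat16) @ b.to_dtype(torch.bfloat16)
        # candidate: fused kernel consuming packed data + scales directly
        out = mx_matmul(a, b)
        torch.testing.assert_close(out, ref, atol=1e-1, rtol=1e-1)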
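For step 3, a sketch of the config-gated dispatch; config.use_triton_scaled_matmul and mx_matmul are hypothetical names.

    # Sketch: route to the fused kernel when the config opts in and the
    # runtime supports it; otherwise keep today's upcast fallback.
    def mx_mm(aten_op, args, kwargs=None):
        a = args[0]
        b = args[1]
        assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
        if config.use_triton_scaled_matmul and has_triton_scaled_dot():
            # consume the packed low-precision data and scales directly
            return mx_matmul(a, b)
        # fallback: upcast and run in the original precision
        a_hp = a.to_dtype(a._orig_dtype)
        b_hp = b.to_dtype(b._orig_dtype)
        return aten_op(a_hp, b_hp)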
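For step 4, a sketch of capturing both traces with torch.profiler and the CUDA caching allocator's memory history; output file names are arbitrary, and mx_matmul is the hypothetical entry point from step 1.

    import torch
    from torch.profiler import profile, ProfilerActivity

    torch.cuda.memory._record_memory_history()
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 profile_memory=True) as prof:
        out = mx_matmul(a, b)
    prof.export_chrome_trace("mx_matmul_trace.json")
    torch.cuda.memory._dump_snapshot("mx_matmul_memory.pickle")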
