Add Intel XPU support to Triton backend #24
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 31da6a64ba
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```diff
         ),
         "per_tensor_scale": ParamConstraint(dtypes=frozenset({torch.float32})),
     },
-    default_devices=cuda_devices,
+    default_devices=triton_devices,
 ),
```
Exclude NVFP4 Triton quantize from XPU selection
The new triton_devices default includes xpu for quantize_nvfp4, so on XPU-only systems the registry will select the Triton backend and attempt to run this kernel. However the Triton NVFP4 quantize kernel uses inline PTX (tl.inline_asm_elementwise with cvt.rn.satfinite.e2m1x2.f32 in backends/triton/quantization.py), which is CUDA-specific and not supported on Intel XPU. This means NVFP4 quantization will fail at runtime on XPU instead of falling back to eager as the commit message claims. Consider restricting quantize_nvfp4 to CUDA-only devices (or adding a hardware gate similar to dequantize_nvfp4) so XPU routes to the eager backend.
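The device gate Codex suggests could be sketched as a registry-selection check. This is a minimal illustration, not the repo's actual API: `select_backend`, `cuda_devices`, and `triton_devices` are hypothetical stand-ins for the real registry machinery.

```python
# Hypothetical sketch: restrict quantize_nvfp4 to CUDA-only devices so an
# XPU-only system routes it to the eager backend instead of crashing on
# CUDA-specific inline PTX. Names here are illustrative, not the repo's API.
cuda_devices = frozenset({"cuda"})          # ops using inline PTX
triton_devices = frozenset({"cuda", "xpu"})  # portable Triton ops

def select_backend(op_name, available_devices, op_devices):
    """Pick 'triton' only when a device the op supports is actually present."""
    if available_devices & op_devices:
        return "triton"
    return "eager"

# On an XPU-only system, NVFP4 (CUDA-only PTX) falls back to eager...
print(select_backend("quantize_nvfp4", {"xpu"}, cuda_devices))    # eager
# ...while FP8 ops registered for triton_devices stay on Triton.
print(select_backend("quantize_fp8", {"xpu"}, triton_devices))    # triton
```

With a gate like this, XPU-only hosts never reach the PTX path, matching the fallback behavior the commit message claims.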
- Add 'xpu' to supported devices in Triton backend constraints
- Detect XPU availability alongside CUDA in backend registration
- Triton kernels (FP8, RoPE) now work on Intel Arc/Data Center GPUs
NVFP4 operations still require CUDA (PTX assembly) and fall back to eager.
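The "detect XPU availability alongside CUDA" step above might look roughly like the sketch below. This is an assumption about the shape of the registration code, not a copy of it; `torch.xpu.is_available()` is real PyTorch API in XPU-enabled builds, and the `importlib` guard only keeps the snippet importable when PyTorch is absent.

```python
import importlib.util

def available_accelerators():
    """Return accelerator device strings ('cuda', 'xpu') visible to PyTorch."""
    if importlib.util.find_spec("torch") is None:
        return []  # no torch installed: nothing to report
    import torch
    devices = []
    if torch.cuda.is_available():
        devices.append("cuda")
    # torch.xpu only exists in XPU-enabled builds; getattr keeps this safe.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        devices.append("xpu")
    return devices
```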
Tested on Intel Arc Pro B60 (PyTorch 2.11.0.dev+xpu):
FP8 Quantize (bf16 -> fp8_e4m3):
256x256: 0.022ms (2.9 GE/s)
1024x1024: 0.022ms (48.3 GE/s)
4096x4096: 0.215ms (78.1 GE/s)
FP8 Dequantize (fp8_e4m3 -> bf16):
256x256: 0.021ms (3.1 GE/s)
1024x1024: 0.022ms (48.5 GE/s)
4096x4096: 0.147ms (113.9 GE/s)
RoPE (bf16):
B=1,H=8,S=128,D=64: 0.024ms (2.7 GE/s)
B=2,H=32,S=1024,D=128: 0.351ms (23.9 GE/s)
B=4,H=32,S=2048,D=128: 0.841ms (39.9 GE/s)
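Numbers like the ones above are typically collected with a warmup-then-average harness; here is a minimal sketch (GE/s means giga-elements processed per second; on a real GPU run you would also call the device's synchronize before reading the clock):

```python
import time

def bench(fn, iters=100, warmup=10):
    """Average wall-clock seconds per call, after a warmup phase."""
    for _ in range(warmup):
        fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

def ge_per_s(numel, seconds):
    """Throughput in giga-elements per second."""
    return numel / seconds / 1e9

# Example: a 4096x4096 tensor has ~16.8M elements; at 0.215 ms per call
# that works out to roughly 78 GE/s, consistent with the quantize table above.
print(round(ge_per_s(4096 * 4096, 0.215e-3), 1))  # 78.0
```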
Force-pushed from 31da6a6 to 200d677
This PR adds Intel XPU (Arc/Data Center GPU) support to the Triton backend.
Changes
NVFP4 operations still require CUDA (PTX assembly) and will fall back to eager mode on XPU.
Benchmarks
Tested on Intel Arc Pro B60 (PyTorch 2.11.0.dev+xpu):