import torch

import triton
import triton.language as tl

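# Per-head FP8 quantization for variable-length batches packed along the
# token axis. Two passes: (1) reduce max(|q|) per (batch, head) and turn it
# into scale = max / FP8_MAX, while recording each token's batch id;
# (2) divide every (token, head) vector by its scale, clamp to the fp8
# range, and cast to float8_e4m3fn.
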
@triton.jit
def _per_head_max_reduce_kernel(
    Q,
    Scales,
    BatchIds,
    StartLoc,
    stride_q_t,
    stride_q_h,
    stride_scales_b,
    SET_BATCH_IDS: tl.constexpr,
    FP8_MAX: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
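    # One program per (batch, head): scan this batch's token range in BLOCK_T
    # chunks, reduce max(|q|) over tokens and the head dim, and store
    # scale = max_val / FP8_MAX (1.0 if the range is empty or all zeros).
    # q is assumed contiguous in its last dim, so stride_q_h == head_dim and
    # doubles as the bound for the head-dim mask.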
    b_id = tl.program_id(0)
    h_id = tl.program_id(1)

    max_val = 0.0

    start_loc = tl.load(StartLoc + b_id)
    end_loc = tl.load(StartLoc + b_id + 1)
    for t_offset in range(start_loc, end_loc, BLOCK_T):
        t_idx = t_offset + tl.arange(0, BLOCK_T)
        q_range = tl.arange(0, BLOCK_D)
        q_ptrs = Q + t_idx[:, None] * stride_q_t + h_id * stride_q_h + q_range[None, :]
        mask = (t_idx[:, None] < end_loc) & (q_range[None, :] < stride_q_h)
        q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
        max_val = tl.maximum(tl.max(q_vals.abs()), max_val)
        if SET_BATCH_IDS:
            tl.store(BatchIds + t_idx, b_id, mask=t_idx < end_loc)

    scale = tl.where(max_val > 0, max_val / FP8_MAX, 1.0)
    scale_ptr = Scales + b_id * stride_scales_b + h_id
    tl.store(scale_ptr, scale)

@triton.jit
def _apply_quantization_kernel(
    Q,
    Q_out,
    BatchIds,
    Scales,
    stride_q_t,
    stride_q_h,
    stride_qout_t,
    stride_qout_h,
    stride_scales_b,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
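    # One program per (token, head): look up the token's batch id, load the
    # matching (batch, head) scale, then divide, clamp to [FP8_MIN, FP8_MAX],
    # and store as float8 (e4m3). As in the reduce kernel, the head-dim
    # strides are assumed to equal head_dim (contiguous last dim).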
    t_id = tl.program_id(0)
    h_id = tl.program_id(1)

    batch_id = tl.load(BatchIds + t_id)
    scale_ptr = Scales + batch_id * stride_scales_b + h_id
    scale = tl.load(scale_ptr)

    q_range = tl.arange(0, BLOCK_D)
    q_ptrs = Q + t_id * stride_q_t + h_id * stride_q_h + q_range
    qout_ptrs = Q_out + t_id * stride_qout_t + h_id * stride_qout_h + q_range
    mask = q_range < stride_q_h
    q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
    q_scaled = q_vals / scale
    q_clamped = tl.clamp(q_scaled, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
    tl.store(qout_ptrs, q_clamped, mask=q_range < stride_qout_h)

@torch.no_grad()
def q_per_head_fp8_quant(q, seq_lens, b1_start_loc):
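    """Quantize q to float8_e4m3fn with one scale per (batch, head).

    q:            (T, H, D) float tensor of tokens packed across batches
                  (assumed contiguous).
    seq_lens:     (B,) tokens per batch; expected to sum to T.
    b1_start_loc: (B + 1,) cumulative start offset of each batch in q.

    Returns (q_out, scales) where q_out is (T, H, D) float8_e4m3fn and
    scales is (B, H) float32.
    """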
    T, H, D = q.shape
    B = seq_lens.shape[0]
    device = q.device

    q_out = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    scales = torch.empty((B, H), dtype=torch.float32, device=device)
    # Pre-zeroed so the single-batch case can skip writing batch ids entirely.
    batch_ids = torch.zeros((T,), dtype=torch.int32, device=device)

    BLOCK_D = triton.next_power_of_2(D)
    BLOCK_T = 256
    num_warps = 4
    num_stages = 2
    _per_head_max_reduce_kernel[(B, H)](
        q,
        scales,
        batch_ids,
        b1_start_loc,
        q.stride(0),
        q.stride(1),
        scales.stride(0),
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        SET_BATCH_IDS=B > 1,
        BLOCK_T=BLOCK_T,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    _apply_quantization_kernel[(T, H)](
        q,
        q_out,
        batch_ids,
        scales,
        q.stride(0),
        q.stride(1),
        q_out.stride(0),
        q_out.stride(1),
        scales.stride(0),
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return q_out, scales

def ref_q_per_head_fp8_quant(q, seq_lens):
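    """Pure-PyTorch reference implementation used to check the Triton path."""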
    min_fp8 = torch.finfo(torch.float8_e4m3fn).min
    max_fp8 = torch.finfo(torch.float8_e4m3fn).max
    B = seq_lens.size(0)
    device = q.device
    batch_ids = torch.repeat_interleave(torch.arange(B, device=device), seq_lens)
    max_per_time_head = q.abs().amax(dim=2)
    max_per_bh = torch.zeros((B, max_per_time_head.size(1)), device=device, dtype=max_per_time_head.dtype)
    max_per_bh.scatter_reduce_(
        0,
        batch_ids.unsqueeze(-1).expand(-1, max_per_time_head.size(1)),
        max_per_time_head,
        reduce="amax",
        include_self=False,
    )
    scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32)
    scale_expanded = scales[batch_ids].view(-1, scales.size(1), 1)
    q_q = (q / scale_expanded).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn)
    return q_q, scales

if __name__ == "__main__":
    B, T, H, D = 200, 1000, 4, 7 * 128
    seq_lens = torch.full((B,), T // B, dtype=torch.int32).cuda()
    start_locs = torch.zeros(B + 1, dtype=torch.int32).cuda()
    start_locs[1:] = seq_lens.cumsum(dim=0)
    q = torch.randn((T, H, D), dtype=torch.float32).cuda()

    q_out, scales = q_per_head_fp8_quant(q, seq_lens, start_locs)
    q_out1, scales1 = ref_q_per_head_fp8_quant(q, seq_lens)
    assert torch.allclose(scales, scales1, atol=1e-10, rtol=0)
    # Compare the fp8 payloads bitwise; casting to int would truncate fractions
    # and let distinct fp8 values compare equal.
    assert (q_out.view(torch.int8) == q_out1.view(torch.int8)).all()