|
| 1 | +import itertools |
| 2 | +from typing import Optional, Tuple, Union |
| 3 | + |
| 4 | +import torch |
| 5 | +import triton |
| 6 | +from flashinfer.norm import fused_add_rmsnorm, rmsnorm |
| 7 | +from torch import nn |
| 8 | + |
| 9 | +from vllm import _custom_ops as vllm_ops |
| 10 | + |
| 11 | + |
| 12 | +class HuggingFaceRMSNorm(nn.Module): |
| 13 | + |
| 14 | + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: |
| 15 | + super().__init__() |
| 16 | + self.weight = nn.Parameter(torch.ones(hidden_size)) |
| 17 | + self.variance_epsilon = eps |
| 18 | + |
| 19 | + def forward( |
| 20 | + self, |
| 21 | + x: torch.Tensor, |
| 22 | + residual: Optional[torch.Tensor] = None, |
| 23 | + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| 24 | + orig_dtype = x.dtype |
| 25 | + x = x.to(torch.float32) |
| 26 | + if residual is not None: |
| 27 | + x = x + residual.to(torch.float32) |
| 28 | + residual = x.to(orig_dtype) |
| 29 | + |
| 30 | + variance = x.pow(2).mean(dim=-1, keepdim=True) |
| 31 | + x = x * torch.rsqrt(variance + self.variance_epsilon) |
| 32 | + x = x.to(orig_dtype) * self.weight |
| 33 | + if residual is None: |
| 34 | + return x |
| 35 | + else: |
| 36 | + return x, residual |
| 37 | + |
| 38 | + |
| 39 | +def rmsnorm_naive( |
| 40 | + x: torch.Tensor, |
| 41 | + weight: torch.Tensor, |
| 42 | + residual: Optional[torch.Tensor] = None, |
| 43 | + eps: float = 1e-6, |
| 44 | +): |
| 45 | + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) |
| 46 | + naive_norm.weight = nn.Parameter(weight) |
| 47 | + naive_norm = naive_norm.to(x.device) |
| 48 | + |
| 49 | + orig_shape = x.shape |
| 50 | + x = x.view(-1, x.shape[-1]) |
| 51 | + if residual is not None: |
| 52 | + residual = residual.view(-1, residual.shape[-1]) |
| 53 | + |
| 54 | + output = naive_norm(x, residual) |
| 55 | + |
| 56 | + if isinstance(output, tuple): |
| 57 | + output = (output[0].view(orig_shape), output[1].view(orig_shape)) |
| 58 | + else: |
| 59 | + output = output.view(orig_shape) |
| 60 | + return output |
| 61 | + |
| 62 | + |
| 63 | +def rmsnorm_flashinfer( |
| 64 | + x: torch.Tensor, |
| 65 | + weight: torch.Tensor, |
| 66 | + residual: Optional[torch.Tensor] = None, |
| 67 | + eps: float = 1e-6, |
| 68 | +): |
| 69 | + orig_shape = x.shape |
| 70 | + x = x.view(-1, x.shape[-1]) |
| 71 | + if residual is not None: |
| 72 | + residual = residual.view(-1, residual.shape[-1]) |
| 73 | + |
| 74 | + if residual is not None: |
| 75 | + fused_add_rmsnorm(x, residual, weight, eps) |
| 76 | + output = (x, residual) |
| 77 | + else: |
| 78 | + output = rmsnorm(x, weight, eps) |
| 79 | + |
| 80 | + if isinstance(output, tuple): |
| 81 | + output = (output[0].view(orig_shape), output[1].view(orig_shape)) |
| 82 | + else: |
| 83 | + output = output.view(orig_shape) |
| 84 | + return output |
| 85 | + |
| 86 | + |
| 87 | +def rmsnorm_vllm( |
| 88 | + x: torch.Tensor, |
| 89 | + weight: torch.Tensor, |
| 90 | + residual: Optional[torch.Tensor] = None, |
| 91 | + eps: float = 1e-6, |
| 92 | +): |
| 93 | + orig_shape = x.shape |
| 94 | + x = x.view(-1, x.shape[-1]) |
| 95 | + if residual is not None: |
| 96 | + residual = residual.view(-1, residual.shape[-1]) |
| 97 | + |
| 98 | + if residual is not None: |
| 99 | + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) |
| 100 | + output = (x, residual) |
| 101 | + else: |
| 102 | + out = torch.empty_like(x) |
| 103 | + vllm_ops.rms_norm(out, x, weight, eps) |
| 104 | + output = out |
| 105 | + |
| 106 | + if isinstance(output, tuple): |
| 107 | + output = (output[0].view(orig_shape), output[1].view(orig_shape)) |
| 108 | + else: |
| 109 | + output = output.view(orig_shape) |
| 110 | + return output |
| 111 | + |
| 112 | + |
| 113 | +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): |
| 114 | + dtype = torch.bfloat16 |
| 115 | + x = torch.randn(batch_size, |
| 116 | + seq_len, |
| 117 | + hidden_size, |
| 118 | + dtype=dtype, |
| 119 | + device="cuda") |
| 120 | + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") |
| 121 | + residual = torch.randn_like(x) if use_residual else None |
| 122 | + |
| 123 | + output_naive = rmsnorm_naive( |
| 124 | + x.clone(), weight, |
| 125 | + residual.clone() if residual is not None else None) |
| 126 | + output_flashinfer = rmsnorm_flashinfer( |
| 127 | + x.clone(), weight, |
| 128 | + residual.clone() if residual is not None else None) |
| 129 | + output_vllm = rmsnorm_vllm( |
| 130 | + x.clone(), weight, |
| 131 | + residual.clone() if residual is not None else None) |
| 132 | + |
| 133 | + if use_residual: |
| 134 | + output_naive = output_naive[0] |
| 135 | + output_flashinfer = output_flashinfer[0] |
| 136 | + output_vllm = output_vllm[0] |
| 137 | + |
| 138 | + print(f"Naive output={output_naive}") |
| 139 | + print(f"FlashInfer output={output_flashinfer}") |
| 140 | + print(f"VLLM output={output_vllm}") |
| 141 | + |
| 142 | + if torch.allclose(output_naive, output_flashinfer, atol=1e-2, |
| 143 | + rtol=1e-2) and torch.allclose( |
| 144 | + output_naive, output_vllm, atol=1e-2, rtol=1e-2): |
| 145 | + print("✅ All implementations match") |
| 146 | + else: |
| 147 | + print("❌ Implementations differ") |
| 148 | + |
| 149 | + |
| 150 | +batch_size_range = [2**i for i in range(0, 7, 2)] |
| 151 | +seq_length_range = [2**i for i in range(6, 11, 1)] |
| 152 | +head_num_range = [32, 48] |
| 153 | +configs = list( |
| 154 | + itertools.product(head_num_range, batch_size_range, seq_length_range)) |
| 155 | + |
| 156 | + |
| 157 | +def get_benchmark(use_residual): |
| 158 | + |
| 159 | + @triton.testing.perf_report( |
| 160 | + triton.testing.Benchmark( |
| 161 | + x_names=["head_num", "batch_size", "seq_len"], |
| 162 | + x_vals=[list(_) for _ in configs], |
| 163 | + line_arg="provider", |
| 164 | + line_vals=["huggingface", "flashinfer", "vllm"], |
| 165 | + line_names=["HuggingFace", "FlashInfer", "vLLM"], |
| 166 | + styles=[("blue", "-"), ("green", "-"), ("red", "-")], |
| 167 | + ylabel="us", |
| 168 | + plot_name= |
| 169 | + f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", |
| 170 | + args={}, |
| 171 | + )) |
| 172 | + def benchmark(head_num, batch_size, seq_len, provider): |
| 173 | + dtype = torch.bfloat16 |
| 174 | + hidden_size = head_num * 128 # assuming head_dim = 128 |
| 175 | + |
| 176 | + x = torch.randn(batch_size, |
| 177 | + seq_len, |
| 178 | + hidden_size, |
| 179 | + dtype=dtype, |
| 180 | + device="cuda") |
| 181 | + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") |
| 182 | + residual = torch.randn_like(x) if use_residual else None |
| 183 | + |
| 184 | + quantiles = [0.5, 0.2, 0.8] |
| 185 | + |
| 186 | + if provider == "huggingface": |
| 187 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 188 | + lambda: rmsnorm_naive( |
| 189 | + x.clone(), |
| 190 | + weight, |
| 191 | + residual.clone() if residual is not None else None, |
| 192 | + ), |
| 193 | + quantiles=quantiles, |
| 194 | + ) |
| 195 | + elif provider == "flashinfer": |
| 196 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 197 | + lambda: rmsnorm_flashinfer( |
| 198 | + x.clone(), |
| 199 | + weight, |
| 200 | + residual.clone() if residual is not None else None, |
| 201 | + ), |
| 202 | + quantiles=quantiles, |
| 203 | + ) |
| 204 | + else: |
| 205 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 206 | + lambda: rmsnorm_vllm( |
| 207 | + x.clone(), |
| 208 | + weight, |
| 209 | + residual.clone() if residual is not None else None, |
| 210 | + ), |
| 211 | + quantiles=quantiles, |
| 212 | + ) |
| 213 | + |
| 214 | + return 1000 * ms, 1000 * max_ms, 1000 * min_ms |
| 215 | + |
| 216 | + return benchmark |
| 217 | + |
| 218 | + |
| 219 | +if __name__ == "__main__": |
| 220 | + import argparse |
| 221 | + |
| 222 | + parser = argparse.ArgumentParser() |
| 223 | + parser.add_argument( |
| 224 | + "--batch-size", |
| 225 | + type=int, |
| 226 | + default=4, |
| 227 | + help="Batch size", |
| 228 | + ) |
| 229 | + parser.add_argument( |
| 230 | + "--seq-len", |
| 231 | + type=int, |
| 232 | + default=128, |
| 233 | + help="Sequence length", |
| 234 | + ) |
| 235 | + parser.add_argument( |
| 236 | + "--hidden-size", |
| 237 | + type=int, |
| 238 | + default=4096, |
| 239 | + help="Hidden size (2nd dimension) of the sequence", |
| 240 | + ) |
| 241 | + parser.add_argument("--use-residual", |
| 242 | + action="store_true", |
| 243 | + help="Whether to use residual connection") |
| 244 | + parser.add_argument( |
| 245 | + "--save-path", |
| 246 | + type=str, |
| 247 | + default="./configs/rmsnorm/", |
| 248 | + help="Path to save rmsnorm benchmark results", |
| 249 | + ) |
| 250 | + |
| 251 | + args = parser.parse_args() |
| 252 | + |
| 253 | + # Run correctness test |
| 254 | + calculate_diff(batch_size=args.batch_size, |
| 255 | + seq_len=args.seq_len, |
| 256 | + hidden_size=args.hidden_size, |
| 257 | + use_residual=args.use_residual) |
| 258 | + |
| 259 | + # Get the benchmark function with proper use_residual setting |
| 260 | + benchmark = get_benchmark(args.use_residual) |
| 261 | + # Run performance benchmark |
| 262 | + benchmark.run(print_data=True, save_path=args.save_path) |
0 commit comments