
Commit 8d69392

misc changes
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
1 parent 1154e07 commit 8d69392

8 files changed, +83 −59 lines

benchmarks/kernels/benchmark_polynorm.py

Lines changed: 29 additions & 31 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import itertools
-from typing import Optional, Union

 import torch

@@ -23,8 +22,16 @@ def norm(x, eps: float):
         return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

     x = x.float()
-    return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
-            weight[2] * norm(x, eps) + bias).to(weight.dtype).view(orig_shape)
+    return (
+        (
+            weight[0] * norm(x**3, eps)
+            + weight[1] * norm(x**2, eps)
+            + weight[2] * norm(x, eps)
+            + bias
+        )
+        .to(weight.dtype)
+        .view(orig_shape)
+    )


 def polynorm_vllm(
@@ -44,18 +51,14 @@ def polynorm_vllm(
     return output


-def calculate_diff(batch_size, seq_len, hidden_size):
+def calculate_diff(batch_size, seq_len, hidden_dim):
     dtype = torch.bfloat16
-    x = torch.randn(batch_size,
-                    seq_len,
-                    hidden_size,
-                    dtype=dtype,
-                    device="cuda")
+    x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
     weight = torch.ones(3, dtype=dtype, device="cuda")
     bais = torch.ones(1, dtype=dtype, device="cuda")

-    output_naive = polynorm_naive(x.clone(), weight, bais)
-    output_vllm = polynorm_vllm(x.clone(), weight, bais)
+    output_naive = polynorm_naive(x, weight, bais)
+    output_vllm = polynorm_vllm(x, weight, bais)

     if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
         print("✅ All implementations match")
@@ -65,47 +68,42 @@ def calculate_diff(batch_size, seq_len, hidden_size):

 batch_size_range = [2**i for i in range(0, 7, 2)]
 seq_length_range = [2**i for i in range(6, 11, 1)]
-head_num_range = [32, 48]
-configs = list(
-    itertools.product(head_num_range, batch_size_range, seq_length_range))
+dim_range = [2048, 4096]
+configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))


 def get_benchmark():
-
     @triton.testing.perf_report(
         triton.testing.Benchmark(
-            x_names=["head_num", "batch_size", "seq_len"],
+            x_names=["dim", "batch_size", "seq_len"],
             x_vals=[list(_) for _ in configs],
             line_arg="provider",
             line_vals=["naive", "vllm"],
             line_names=["Naive", "vLLM"],
             styles=[("blue", "-"), ("red", "-")],
             ylabel="us",
-            plot_name=f"polynorm-perf",
+            plot_name="polynorm-perf",
             args={},
-        ))
-    def benchmark(head_num, batch_size, seq_len, provider):
+        )
+    )
+    def benchmark(dim, batch_size, seq_len, provider):
         dtype = torch.bfloat16
-        hidden_size = head_num * 128  # assuming head_dim = 128
+        hidden_dim = dim * 4

-        x = torch.randn(batch_size,
-                        seq_len,
-                        hidden_size,
-                        dtype=dtype,
-                        device="cuda")
+        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
         weight = torch.ones(3, dtype=dtype, device="cuda")
         bias = torch.ones(1, dtype=dtype, device="cuda")

         quantiles = [0.5, 0.2, 0.8]

         if provider == "naive":
             ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: polynorm_naive(x.clone(), weight, bias),
+                lambda: polynorm_naive(x, weight, bias),
                 quantiles=quantiles,
             )
         else:
             ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: polynorm_vllm(x.clone(), weight, bias),
+                lambda: polynorm_vllm(x, weight, bias),
                 quantiles=quantiles,
             )

@@ -131,10 +129,10 @@ def benchmark(head_num, batch_size, seq_len, provider):
         help="Sequence length",
     )
     parser.add_argument(
-        "--hidden-size",
+        "--hidden-dim",
         type=int,
-        default=4096,
-        help="Hidden size (2nd dimension) of the sequence",
+        default=8192,
+        help="Intermediate size of MLP",
     )
     parser.add_argument(
         "--save-path",
@@ -149,7 +147,7 @@ def benchmark(head_num, batch_size, seq_len, provider):
     calculate_diff(
         batch_size=args.batch_size,
         seq_len=args.seq_len,
-        hidden_size=args.hidden_size,
+        hidden_dim=args.hidden_dim,
     )

     benchmark = get_benchmark()
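
For reference, polynorm_naive (the baseline the vLLM kernel is compared against) RMS-normalizes x, x**2, and x**3 and mixes them with a three-element weight plus a bias. A self-contained sketch reconstructed from the hunk above; the wrapper signature, the reshaping, and the default eps are assumptions, only the inner norm() and the return expression are verbatim from the diff:

    import torch


    def polynorm_naive(
        x: torch.Tensor,
        weight: torch.Tensor,  # shape [3]
        bias: torch.Tensor,    # shape [1]
        eps: float = 1e-6,
    ):
        # Assumed wrapper: flatten to 2D, normalize, then restore the shape.
        orig_shape = x.shape
        x = x.view(-1, x.shape[-1])

        def norm(x, eps: float):
            # RMS-normalize along the last (hidden) dimension.
            return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

        x = x.float()
        return (
            (
                weight[0] * norm(x**3, eps)
                + weight[1] * norm(x**2, eps)
                + weight[2] * norm(x, eps)
                + bias
            )
            .to(weight.dtype)
            .view(orig_shape)
        )

With the renamed CLI flag, a correctness check might be run as: python benchmarks/kernels/benchmark_polynorm.py --hidden-dim 8192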

csrc/layernorm_kernels.cu

Lines changed: 34 additions & 17 deletions
@@ -203,7 +203,7 @@ struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
-                 scalar_t* __restrict__ input,         // [..., hidden_size]
+                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                  const scalar_t* __restrict__ weight,  // [3]
                  const scalar_t* __restrict__ bias,    // [1]
                  const float epsilon, const int hidden_size) {
@@ -215,7 +215,7 @@ poly_norm_kernel(scalar_t* __restrict__ out,  // [..., hidden_size]
      not aliased in practice. Argument pointers should not be dereferenced
      in this kernel as that would be undefined behavior */
   auto* __restrict__ input_v =
-      reinterpret_cast<_f16VecPN<scalar_t, width>*>(input);
+      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
   const int vec_hidden_size = hidden_size / width;
   float variance = 0.0f;
   float variance2 = 0.0f;
@@ -231,14 +231,22 @@ poly_norm_kernel(scalar_t* __restrict__ out,  // [..., hidden_size]
     variance3 += x6;
   }

-  using BlockReduce = cub::BlockReduce<float, 1024>;
+  float3 thread_variances = make_float3(variance, variance2, variance3);
+
+  struct SumOp {
+    __device__ float3 operator()(const float3& a, const float3& b) const {
+      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
+    }
+  };
+
+  using BlockReduce = cub::BlockReduce<float3, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
+  float3 block_variances =
+      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

-  variance = BlockReduce(reduceStore).Sum(variance, blockDim.x);
-  __syncthreads();
-  variance2 = BlockReduce(reduceStore).Sum(variance2, blockDim.x);
-  __syncthreads();
-  variance3 = BlockReduce(reduceStore).Sum(variance3, blockDim.x);
+  variance = block_variances.x;
+  variance2 = block_variances.y;
+  variance3 = block_variances.z;

   __shared__ float s_w2_inv_std;
   __shared__ float s_w1_inv_std2;
@@ -273,7 +281,7 @@ poly_norm_kernel(scalar_t* __restrict__ out,  // [..., hidden_size]
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
-                 scalar_t* __restrict__ input,         // [..., hidden_size]
+                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                  const scalar_t* __restrict__ weight,  // [3]
                  const scalar_t* __restrict__ bias,    // [1]
                  const float epsilon, const int hidden_size) {
@@ -292,14 +300,22 @@ poly_norm_kernel(scalar_t* __restrict__ out,  // [..., hidden_size]
     variance3 += x6;
   }

-  using BlockReduce = cub::BlockReduce<float, 1024>;
+  float3 thread_variances = make_float3(variance, variance2, variance3);
+
+  struct SumOp {
+    __device__ float3 operator()(const float3& a, const float3& b) const {
+      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
+    }
+  };
+
+  using BlockReduce = cub::BlockReduce<float3, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
+  float3 block_variances =
+      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

-  variance = BlockReduce(reduceStore).Sum(variance, blockDim.x);
-  __syncthreads();
-  variance2 = BlockReduce(reduceStore).Sum(variance2, blockDim.x);
-  __syncthreads();
-  variance3 = BlockReduce(reduceStore).Sum(variance3, blockDim.x);
+  variance = block_variances.x;
+  variance2 = block_variances.y;
+  variance3 = block_variances.z;

   __shared__ float s_w2_inv_std;
   __shared__ float s_w1_inv_std2;
@@ -323,8 +339,9 @@ poly_norm_kernel(scalar_t* __restrict__ out,  // [..., hidden_size]
     float x2 = x * x;
     float x3 = x2 * x;

-    out[blockIdx.x * hidden_size + idx] = (scalar_t)(
-        x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + s_bias);
+    out[blockIdx.x * hidden_size + idx] =
+        (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
+                   s_bias);
   }
 }
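
The substantive change in both kernel variants is the block reduction: three sequential cub::BlockReduce<float>::Sum calls, separated by __syncthreads() so the shared temp storage could be reused, are fused into a single cub::BlockReduce<float3>::Reduce with a custom SumOp, so the block-wide sums of x^2, x^4, and x^6 come out of one reduction. For orientation, the per-row math those sums feed matches the naive Python reference in the benchmark; a minimal sketch, assuming the sums are turned into means over hidden_size and an eps of 1e-6 (neither detail is visible in this hunk):

    import torch


    def poly_norm_rowwise(x_row: torch.Tensor, weight: torch.Tensor,
                          bias: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # What one thread block computes for a single row; names mirror the kernel.
        x = x_row.float()
        variance = (x**2).mean()   # block-wide sum of x^2, divided by hidden_size
        variance2 = (x**4).mean()  # block-wide sum of x^4, divided by hidden_size
        variance3 = (x**6).mean()  # block-wide sum of x^6, divided by hidden_size
        s_w2_inv_std = weight[2].float() * torch.rsqrt(variance + eps)
        s_w1_inv_std2 = weight[1].float() * torch.rsqrt(variance2 + eps)
        s_w0_inv_std3 = weight[0].float() * torch.rsqrt(variance3 + eps)
        s_bias = bias[0].float()
        return (x * s_w2_inv_std + x**2 * s_w1_inv_std2 +
                x**3 * s_w0_inv_std3 + s_bias).to(x_row.dtype)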

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -382,6 +382,7 @@ th {
 | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. | | ✅︎ | |
 | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
 | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |

tests/kernels/core/test_layernorm.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@

 from tests.kernels.quant_utils import FP8_DTYPE
 from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.layernorm import RMSNorm, PolyNorm
+from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
 from vllm.platforms import current_platform

 DTYPES = [torch.half, torch.bfloat16, torch.float]

tests/models/registry.py

Lines changed: 2 additions & 1 deletion
@@ -258,7 +258,8 @@ def check_available_online(
                                              {"tiny": "TitanML/tiny-mixtral"}),  # noqa: E501
     "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"),  # noqa: E501
     "MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B",
-                                        trust_remote_code=True),
+                                        trust_remote_code=True,
+                                        v0_only=True),
     "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
     "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
     "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),

tests/models/test_initialization.py

Lines changed: 3 additions & 2 deletions
@@ -65,8 +65,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
                            _initialize_kv_caches_v1), monkeypatch.context() as m):
         if model_info.v0_only:
             m.setenv("VLLM_USE_V1", "0")
-        if model_arch == "Phi4FlashForCausalLM":
-            # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
+        if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"):
+            # Phi4FlashForCausalLM and MotifForCausalLM
+            # only supports DIFFERENTIAL_FLASH_ATTN backend
             m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
         if model_arch == "GptOssForCausalLM":
             # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
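
Outside the test harness, the same backend pin can be applied through the environment before vLLM is imported; a rough sketch mirroring what the test sets via monkeypatch (the model choice and generate call are illustrative only):

    import os

    # Motif, like Phi4Flash, currently needs the differential flash-attention
    # backend; the registry change above also marks it as v0-only.
    os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"
    os.environ["VLLM_USE_V1"] = "0"

    from vllm import LLM

    llm = LLM(model="Motif-Technologies/Motif-2.6B", trust_remote_code=True)
    print(llm.generate("Hello")[0].outputs[0].text)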

vllm/attention/backends/differential_flash_attn.py

Lines changed: 3 additions & 0 deletions
@@ -734,6 +734,7 @@ def forward_generate_kv_cache(
                 window_size=self.sliding_window,
                 alibi_slopes=self.alibi_slopes,
                 softcap=self.logits_soft_cap,
+                fa_version=self.vllm_flash_attn_version,
             )
             assert prefill_output.shape == output[:
                                                   num_prefill_tokens].shape
@@ -755,6 +756,7 @@ def forward_generate_kv_cache(
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
+                    fa_version=self.vllm_flash_attn_version,
                 ).squeeze(1)
             except Exception as e:
                 logger.error("Error in PagedAttention.forward_decode: %s",
@@ -787,6 +789,7 @@ def forward_with_kv_cache_only(
             window_size=self.sliding_window,
             alibi_slopes=self.alibi_slopes,
             softcap=self.logits_soft_cap,
+            fa_version=self.vllm_flash_attn_version,
         ).squeeze(1)
         return output

vllm/model_executor/models/motif.py

Lines changed: 10 additions & 7 deletions
@@ -19,7 +19,7 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.layernorm import RMSNorm, PolyNorm
+from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
@@ -72,17 +72,20 @@ def __init__(
             prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "poly_norm":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only poly_norm is supported for now.")
+            raise NotImplementedError(f"Unsupported activation: {hidden_act}. "
+                                      "Only poly_norm is supported for now.")
         self.act_fn = PolyNorm()
         self.intermediate_size = intermediate_size
-        self.tp_size = get_tensor_model_parallel_world_size()
+        tp_size = get_tensor_model_parallel_world_size()
+        if hidden_act == "poly_norm" and tp_size > 1:
+            raise NotImplementedError(
+                "Tensor parallelism for poly_norm is not supported yet. "
+                "Support will be added in the future.")

     def forward(self, x):
         x, _ = self.gate_up_proj(x)
         x = self.act_fn(
-            x[..., :self.intermediate_size //
-              self.tp_size]) * x[..., self.intermediate_size // self.tp_size:]
+            x[..., :self.intermediate_size]) * x[..., self.intermediate_size:]
         x, _ = self.down_proj(x)
         return x

@@ -175,7 +178,7 @@ def __init__(
         self.lambda_k2 = nn.Parameter(
             torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                     std=0.1))
-        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5)
+        self.subln = RMSNorm(2 * self.head_dim, eps=config.attn_rms_norm_eps)

         params = {
             'differential_flash_attention_config': {
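
The MLP change drops the per-rank slicing by self.tp_size (tensor parallelism with poly_norm now raises NotImplementedError instead), so the activation is applied to the first intermediate_size channels of the fused gate_up projection and gated by the remaining channels. A single-GPU sketch of that forward pass, using plain nn.Linear and a stand-in RMS activation in place of vLLM's parallel linear layers and PolyNorm (all names here are illustrative):

    import torch
    import torch.nn as nn


    class MotifMLPSketch(nn.Module):
        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            self.intermediate_size = intermediate_size
            # Stand-ins for MergedColumnParallelLinear / RowParallelLinear.
            self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size)
            self.down_proj = nn.Linear(intermediate_size, hidden_size)
            # Stand-in for PolyNorm: a plain RMS normalization.
            self.act_fn = lambda t: t / torch.sqrt(
                t.pow(2).mean(-1, keepdim=True) + 1e-6)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.gate_up_proj(x)
            # Activation on the first intermediate_size channels, gated by the rest.
            x = (self.act_fn(x[..., :self.intermediate_size]) *
                 x[..., self.intermediate_size:])
            return self.down_proj(x)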
