Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 1996acf

Browse files
jinzhen-lin authored and Robert Shaw committed
[Kernel] Add punica dimension for Qwen2 LoRA (vllm-project#5441)
1 parent b05443a commit 1996acf

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

csrc/punica/bgmv/bgmv_config.h

Lines changed: 36 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,24 +16,33 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
1616
f(in_T, out_T, W_T, narrow, 512) \
1717
f(in_T, out_T, W_T, narrow, 640) \
1818
f(in_T, out_T, W_T, narrow, 768) \
19+
f(in_T, out_T, W_T, narrow, 896) \
1920
f(in_T, out_T, W_T, narrow, 1024) \
2021
f(in_T, out_T, W_T, narrow, 1152) \
22+
f(in_T, out_T, W_T, narrow, 1216) \
2123
f(in_T, out_T, W_T, narrow, 1280) \
2224
f(in_T, out_T, W_T, narrow, 1536) \
2325
f(in_T, out_T, W_T, narrow, 1664) \
2426
f(in_T, out_T, W_T, narrow, 1728) \
2527
f(in_T, out_T, W_T, narrow, 1792) \
2628
f(in_T, out_T, W_T, narrow, 2048) \
29+
f(in_T, out_T, W_T, narrow, 2240) \
2730
f(in_T, out_T, W_T, narrow, 2304) \
31+
f(in_T, out_T, W_T, narrow, 2368) \
32+
f(in_T, out_T, W_T, narrow, 2432) \
2833
f(in_T, out_T, W_T, narrow, 2560) \
2934
f(in_T, out_T, W_T, narrow, 2752) \
3035
f(in_T, out_T, W_T, narrow, 2816) \
3136
f(in_T, out_T, W_T, narrow, 3072) \
3237
f(in_T, out_T, W_T, narrow, 3328) \
3338
f(in_T, out_T, W_T, narrow, 3456) \
3439
f(in_T, out_T, W_T, narrow, 3584) \
40+
f(in_T, out_T, W_T, narrow, 3712) \
3541
f(in_T, out_T, W_T, narrow, 4096) \
42+
f(in_T, out_T, W_T, narrow, 4480) \
3643
f(in_T, out_T, W_T, narrow, 4608) \
44+
f(in_T, out_T, W_T, narrow, 4736) \
45+
f(in_T, out_T, W_T, narrow, 4864) \
3746
f(in_T, out_T, W_T, narrow, 5120) \
3847
f(in_T, out_T, W_T, narrow, 5504) \
3948
f(in_T, out_T, W_T, narrow, 5632) \
@@ -43,24 +52,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
4352
f(in_T, out_T, W_T, narrow, 6848) \
4453
f(in_T, out_T, W_T, narrow, 6912) \
4554
f(in_T, out_T, W_T, narrow, 7168) \
55+
f(in_T, out_T, W_T, narrow, 7424) \
4656
f(in_T, out_T, W_T, narrow, 8192) \
57+
f(in_T, out_T, W_T, narrow, 8960) \
4758
f(in_T, out_T, W_T, narrow, 9216) \
59+
f(in_T, out_T, W_T, narrow, 9472) \
4860
f(in_T, out_T, W_T, narrow, 10240) \
4961
f(in_T, out_T, W_T, narrow, 11008) \
5062
f(in_T, out_T, W_T, narrow, 11264) \
5163
f(in_T, out_T, W_T, narrow, 12288) \
5264
f(in_T, out_T, W_T, narrow, 13696) \
5365
f(in_T, out_T, W_T, narrow, 13824) \
5466
f(in_T, out_T, W_T, narrow, 14336) \
67+
f(in_T, out_T, W_T, narrow, 14784) \
68+
f(in_T, out_T, W_T, narrow, 14848) \
5569
f(in_T, out_T, W_T, narrow, 15360) \
5670
f(in_T, out_T, W_T, narrow, 16384) \
71+
f(in_T, out_T, W_T, narrow, 18944) \
5772
f(in_T, out_T, W_T, narrow, 20480) \
5873
f(in_T, out_T, W_T, narrow, 22016) \
5974
f(in_T, out_T, W_T, narrow, 22528) \
6075
f(in_T, out_T, W_T, narrow, 24576) \
6176
f(in_T, out_T, W_T, narrow, 27392) \
6277
f(in_T, out_T, W_T, narrow, 27648) \
6378
f(in_T, out_T, W_T, narrow, 28672) \
79+
f(in_T, out_T, W_T, narrow, 29568) \
80+
f(in_T, out_T, W_T, narrow, 29696) \
6481
f(in_T, out_T, W_T, narrow, 32000) \
6582
f(in_T, out_T, W_T, narrow, 32256) \
6683
f(in_T, out_T, W_T, narrow, 32512) \
@@ -85,34 +102,43 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
85102
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
86103
// and vllm/tests/lora/test_punica.py
87104

88-
// Used for defining kernels going from the variety of
105+
// Used for defining kernels going from the variety of
89106
// dim in to the narrow dim out
90-
// Using it for the fully sharded column
107+
// Using it for the fully sharded column
91108
// parallel LoRA A which splits the rank dim
92109
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
93110
f(in_T, out_T, W_T, 128, narrow) \
94111
f(in_T, out_T, W_T, 256, narrow) \
95112
f(in_T, out_T, W_T, 512, narrow) \
96113
f(in_T, out_T, W_T, 640, narrow) \
97114
f(in_T, out_T, W_T, 768, narrow) \
115+
f(in_T, out_T, W_T, 896, narrow) \
98116
f(in_T, out_T, W_T, 1024, narrow) \
99117
f(in_T, out_T, W_T, 1152, narrow) \
118+
f(in_T, out_T, W_T, 1216, narrow) \
100119
f(in_T, out_T, W_T, 1280, narrow) \
101120
f(in_T, out_T, W_T, 1536, narrow) \
102121
f(in_T, out_T, W_T, 1664, narrow) \
103122
f(in_T, out_T, W_T, 1728, narrow) \
104123
f(in_T, out_T, W_T, 1792, narrow) \
105124
f(in_T, out_T, W_T, 2048, narrow) \
125+
f(in_T, out_T, W_T, 2240, narrow) \
106126
f(in_T, out_T, W_T, 2304, narrow) \
127+
f(in_T, out_T, W_T, 2368, narrow) \
128+
f(in_T, out_T, W_T, 2432, narrow) \
107129
f(in_T, out_T, W_T, 2560, narrow) \
108130
f(in_T, out_T, W_T, 2752, narrow) \
109131
f(in_T, out_T, W_T, 2816, narrow) \
110132
f(in_T, out_T, W_T, 3072, narrow) \
111133
f(in_T, out_T, W_T, 3328, narrow) \
112134
f(in_T, out_T, W_T, 3456, narrow) \
113135
f(in_T, out_T, W_T, 3584, narrow) \
136+
f(in_T, out_T, W_T, 3712, narrow) \
114137
f(in_T, out_T, W_T, 4096, narrow) \
138+
f(in_T, out_T, W_T, 4480, narrow) \
115139
f(in_T, out_T, W_T, 4608, narrow) \
140+
f(in_T, out_T, W_T, 4736, narrow) \
141+
f(in_T, out_T, W_T, 4864, narrow) \
116142
f(in_T, out_T, W_T, 5120, narrow) \
117143
f(in_T, out_T, W_T, 5504, narrow) \
118144
f(in_T, out_T, W_T, 5632, narrow) \
@@ -122,24 +148,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
122148
f(in_T, out_T, W_T, 6848, narrow) \
123149
f(in_T, out_T, W_T, 6912, narrow) \
124150
f(in_T, out_T, W_T, 7168, narrow) \
151+
f(in_T, out_T, W_T, 7424, narrow) \
125152
f(in_T, out_T, W_T, 8192, narrow) \
153+
f(in_T, out_T, W_T, 8960, narrow) \
126154
f(in_T, out_T, W_T, 9216, narrow) \
155+
f(in_T, out_T, W_T, 9472, narrow) \
127156
f(in_T, out_T, W_T, 10240, narrow) \
128157
f(in_T, out_T, W_T, 11008, narrow) \
129158
f(in_T, out_T, W_T, 11264, narrow) \
130159
f(in_T, out_T, W_T, 12288, narrow) \
131160
f(in_T, out_T, W_T, 13696, narrow) \
132161
f(in_T, out_T, W_T, 13824, narrow) \
133162
f(in_T, out_T, W_T, 14336, narrow) \
163+
f(in_T, out_T, W_T, 14784, narrow) \
164+
f(in_T, out_T, W_T, 14848, narrow) \
134165
f(in_T, out_T, W_T, 15360, narrow) \
135166
f(in_T, out_T, W_T, 16384, narrow) \
167+
f(in_T, out_T, W_T, 18944, narrow) \
136168
f(in_T, out_T, W_T, 20480, narrow) \
137169
f(in_T, out_T, W_T, 22016, narrow) \
138170
f(in_T, out_T, W_T, 22528, narrow) \
139171
f(in_T, out_T, W_T, 24576, narrow) \
140172
f(in_T, out_T, W_T, 27392, narrow) \
141173
f(in_T, out_T, W_T, 27648, narrow) \
142174
f(in_T, out_T, W_T, 28672, narrow) \
175+
f(in_T, out_T, W_T, 29568, narrow) \
176+
f(in_T, out_T, W_T, 29696, narrow) \
143177
f(in_T, out_T, W_T, 32000, narrow) \
144178
f(in_T, out_T, W_T, 32256, narrow) \
145179
f(in_T, out_T, W_T, 32512, narrow) \

tests/lora/test_punica.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,21 +54,30 @@ def _lora_ref_impl(
5454
128,
5555
256,
5656
512,
57+
896,
5758
1024,
5859
1152,
60+
1216,
5961
1280,
6062
1536,
6163
1664,
6264
2048,
65+
2240,
6366
2304,
67+
2368,
68+
2432,
6469
2560,
6570
2752,
6671
3072,
6772
3328,
6873
3456,
6974
3584,
75+
3712,
7076
4096,
77+
4480,
7178
4608,
79+
4736,
80+
4864,
7281
5120,
7382
5504,
7483
5632,
@@ -78,19 +87,27 @@ def _lora_ref_impl(
7887
6848,
7988
6912,
8089
7168,
90+
7424,
8191
8192,
92+
8960,
8293
9216,
94+
9472,
8395
10240,
8496
11008,
8597
11264,
8698
13824,
8799
14336,
100+
14784,
101+
14848,
88102
15360,
103+
18944,
89104
22016,
90105
22528,
91106
24576,
92107
27392,
93108
27648,
109+
29568,
110+
29696,
94111
32000,
95112
32256,
96113
32512,

0 commit comments

Comments
 (0)