@@ -16,24 +16,33 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
16
16
f(in_T, out_T, W_T, narrow, 512 ) \
17
17
f(in_T, out_T, W_T, narrow, 640 ) \
18
18
f(in_T, out_T, W_T, narrow, 768 ) \
19
+ f(in_T, out_T, W_T, narrow, 896 ) \
19
20
f(in_T, out_T, W_T, narrow, 1024 ) \
20
21
f(in_T, out_T, W_T, narrow, 1152 ) \
22
+ f(in_T, out_T, W_T, narrow, 1216 ) \
21
23
f(in_T, out_T, W_T, narrow, 1280 ) \
22
24
f(in_T, out_T, W_T, narrow, 1536 ) \
23
25
f(in_T, out_T, W_T, narrow, 1664 ) \
24
26
f(in_T, out_T, W_T, narrow, 1728 ) \
25
27
f(in_T, out_T, W_T, narrow, 1792 ) \
26
28
f(in_T, out_T, W_T, narrow, 2048 ) \
29
+ f(in_T, out_T, W_T, narrow, 2240 ) \
27
30
f(in_T, out_T, W_T, narrow, 2304 ) \
31
+ f(in_T, out_T, W_T, narrow, 2368 ) \
32
+ f(in_T, out_T, W_T, narrow, 2432 ) \
28
33
f(in_T, out_T, W_T, narrow, 2560 ) \
29
34
f(in_T, out_T, W_T, narrow, 2752 ) \
30
35
f(in_T, out_T, W_T, narrow, 2816 ) \
31
36
f(in_T, out_T, W_T, narrow, 3072 ) \
32
37
f(in_T, out_T, W_T, narrow, 3328 ) \
33
38
f(in_T, out_T, W_T, narrow, 3456 ) \
34
39
f(in_T, out_T, W_T, narrow, 3584 ) \
40
+ f(in_T, out_T, W_T, narrow, 3712 ) \
35
41
f(in_T, out_T, W_T, narrow, 4096 ) \
42
+ f(in_T, out_T, W_T, narrow, 4480 ) \
36
43
f(in_T, out_T, W_T, narrow, 4608 ) \
44
+ f(in_T, out_T, W_T, narrow, 4736 ) \
45
+ f(in_T, out_T, W_T, narrow, 4864 ) \
37
46
f(in_T, out_T, W_T, narrow, 5120 ) \
38
47
f(in_T, out_T, W_T, narrow, 5504 ) \
39
48
f(in_T, out_T, W_T, narrow, 5632 ) \
@@ -43,24 +52,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
43
52
f(in_T, out_T, W_T, narrow, 6848 ) \
44
53
f(in_T, out_T, W_T, narrow, 6912 ) \
45
54
f(in_T, out_T, W_T, narrow, 7168 ) \
55
+ f(in_T, out_T, W_T, narrow, 7424 ) \
46
56
f(in_T, out_T, W_T, narrow, 8192 ) \
57
+ f(in_T, out_T, W_T, narrow, 8960 ) \
47
58
f(in_T, out_T, W_T, narrow, 9216 ) \
59
+ f(in_T, out_T, W_T, narrow, 9472 ) \
48
60
f(in_T, out_T, W_T, narrow, 10240 ) \
49
61
f(in_T, out_T, W_T, narrow, 11008 ) \
50
62
f(in_T, out_T, W_T, narrow, 11264 ) \
51
63
f(in_T, out_T, W_T, narrow, 12288 ) \
52
64
f(in_T, out_T, W_T, narrow, 13696 ) \
53
65
f(in_T, out_T, W_T, narrow, 13824 ) \
54
66
f(in_T, out_T, W_T, narrow, 14336 ) \
67
+ f(in_T, out_T, W_T, narrow, 14784 ) \
68
+ f(in_T, out_T, W_T, narrow, 14848 ) \
55
69
f(in_T, out_T, W_T, narrow, 15360 ) \
56
70
f(in_T, out_T, W_T, narrow, 16384 ) \
71
+ f(in_T, out_T, W_T, narrow, 18944 ) \
57
72
f(in_T, out_T, W_T, narrow, 20480 ) \
58
73
f(in_T, out_T, W_T, narrow, 22016 ) \
59
74
f(in_T, out_T, W_T, narrow, 22528 ) \
60
75
f(in_T, out_T, W_T, narrow, 24576 ) \
61
76
f(in_T, out_T, W_T, narrow, 27392 ) \
62
77
f(in_T, out_T, W_T, narrow, 27648 ) \
63
78
f(in_T, out_T, W_T, narrow, 28672 ) \
79
+ f(in_T, out_T, W_T, narrow, 29568 ) \
80
+ f(in_T, out_T, W_T, narrow, 29696 ) \
64
81
f(in_T, out_T, W_T, narrow, 32000 ) \
65
82
f(in_T, out_T, W_T, narrow, 32256 ) \
66
83
f(in_T, out_T, W_T, narrow, 32512 ) \
@@ -85,34 +102,43 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
85
102
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
86
103
// and vllm/tests/lora/test_punica.py
87
104
88
- // Used for defining kernels going from the variety of
105
+ // Used for defining kernels going from the variety of
89
106
// dim in to the narrow dim out
90
- // Using it for the fully sharded column
107
+ // Using it for the fully sharded column
91
108
// parallel LoRA A which splits the rank dim
92
109
#define FOR_INST_BGMV_NARROW (f, in_T, out_T, W_T, narrow ) \
93
110
f (in_T, out_T, W_T, 128 , narrow) \
94
111
f(in_T, out_T, W_T, 256 , narrow) \
95
112
f(in_T, out_T, W_T, 512 , narrow) \
96
113
f(in_T, out_T, W_T, 640 , narrow) \
97
114
f(in_T, out_T, W_T, 768 , narrow) \
115
+ f(in_T, out_T, W_T, 896 , narrow) \
98
116
f(in_T, out_T, W_T, 1024 , narrow) \
99
117
f(in_T, out_T, W_T, 1152 , narrow) \
118
+ f(in_T, out_T, W_T, 1216 , narrow) \
100
119
f(in_T, out_T, W_T, 1280 , narrow) \
101
120
f(in_T, out_T, W_T, 1536 , narrow) \
102
121
f(in_T, out_T, W_T, 1664 , narrow) \
103
122
f(in_T, out_T, W_T, 1728 , narrow) \
104
123
f(in_T, out_T, W_T, 1792 , narrow) \
105
124
f(in_T, out_T, W_T, 2048 , narrow) \
125
+ f(in_T, out_T, W_T, 2240 , narrow) \
106
126
f(in_T, out_T, W_T, 2304 , narrow) \
127
+ f(in_T, out_T, W_T, 2368 , narrow) \
128
+ f(in_T, out_T, W_T, 2432 , narrow) \
107
129
f(in_T, out_T, W_T, 2560 , narrow) \
108
130
f(in_T, out_T, W_T, 2752 , narrow) \
109
131
f(in_T, out_T, W_T, 2816 , narrow) \
110
132
f(in_T, out_T, W_T, 3072 , narrow) \
111
133
f(in_T, out_T, W_T, 3328 , narrow) \
112
134
f(in_T, out_T, W_T, 3456 , narrow) \
113
135
f(in_T, out_T, W_T, 3584 , narrow) \
136
+ f(in_T, out_T, W_T, 3712 , narrow) \
114
137
f(in_T, out_T, W_T, 4096 , narrow) \
138
+ f(in_T, out_T, W_T, 4480 , narrow) \
115
139
f(in_T, out_T, W_T, 4608 , narrow) \
140
+ f(in_T, out_T, W_T, 4736 , narrow) \
141
+ f(in_T, out_T, W_T, 4864 , narrow) \
116
142
f(in_T, out_T, W_T, 5120 , narrow) \
117
143
f(in_T, out_T, W_T, 5504 , narrow) \
118
144
f(in_T, out_T, W_T, 5632 , narrow) \
@@ -122,24 +148,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
122
148
f(in_T, out_T, W_T, 6848 , narrow) \
123
149
f(in_T, out_T, W_T, 6912 , narrow) \
124
150
f(in_T, out_T, W_T, 7168 , narrow) \
151
+ f(in_T, out_T, W_T, 7424 , narrow) \
125
152
f(in_T, out_T, W_T, 8192 , narrow) \
153
+ f(in_T, out_T, W_T, 8960 , narrow) \
126
154
f(in_T, out_T, W_T, 9216 , narrow) \
155
+ f(in_T, out_T, W_T, 9472 , narrow) \
127
156
f(in_T, out_T, W_T, 10240 , narrow) \
128
157
f(in_T, out_T, W_T, 11008 , narrow) \
129
158
f(in_T, out_T, W_T, 11264 , narrow) \
130
159
f(in_T, out_T, W_T, 12288 , narrow) \
131
160
f(in_T, out_T, W_T, 13696 , narrow) \
132
161
f(in_T, out_T, W_T, 13824 , narrow) \
133
162
f(in_T, out_T, W_T, 14336 , narrow) \
163
+ f(in_T, out_T, W_T, 14784 , narrow) \
164
+ f(in_T, out_T, W_T, 14848 , narrow) \
134
165
f(in_T, out_T, W_T, 15360 , narrow) \
135
166
f(in_T, out_T, W_T, 16384 , narrow) \
167
+ f(in_T, out_T, W_T, 18944 , narrow) \
136
168
f(in_T, out_T, W_T, 20480 , narrow) \
137
169
f(in_T, out_T, W_T, 22016 , narrow) \
138
170
f(in_T, out_T, W_T, 22528 , narrow) \
139
171
f(in_T, out_T, W_T, 24576 , narrow) \
140
172
f(in_T, out_T, W_T, 27392 , narrow) \
141
173
f(in_T, out_T, W_T, 27648 , narrow) \
142
174
f(in_T, out_T, W_T, 28672 , narrow) \
175
+ f(in_T, out_T, W_T, 29568 , narrow) \
176
+ f(in_T, out_T, W_T, 29696 , narrow) \
143
177
f(in_T, out_T, W_T, 32000 , narrow) \
144
178
f(in_T, out_T, W_T, 32256 , narrow) \
145
179
f(in_T, out_T, W_T, 32512 , narrow) \
0 commit comments