Commit b5ceeed

Relax FP8 TP requirement (InternLM#3697)

1 parent bc87b22 commit b5ceeed

File tree: 9 files changed, +104 -74 lines

lmdeploy/turbomind/deploy/module.py

Lines changed: 35 additions & 34 deletions

@@ -100,26 +100,18 @@ def __init__(self, model: BaseOutputModel):
         self.inter_size = model.model_config.inter_size
         self.group_size = max(1, model.model_config.group_size)
 
-    def _export(self,
-                inter_size: int,
-                fmt: str,
-                idx: int,
-                w123,
-                kind: str,
-                pack_fn,
-                apply_gs=False,
-                block_size=1,
-                **kwargs):
+    def _export(self, inter_size: int, fmt: str, idx: int, w123, kind: str, pack_fn, apply_gs=[], **kwargs):
         is_lora_a, is_lora_b = get_lora_flags(kind)
         w1, w2, w3 = map(transpose, w123)
 
-        # TODO: handle padding for block_size != 1
-        if not is_lora_a and block_size == 1:
-            w1 = pad_out_dims(w1, inter_size)
-            w3 = pad_out_dims(w3, inter_size)
-        if not is_lora_b and block_size == 1:
-            group_size = self.group_size if apply_gs else 1
-            w2 = pad_in_dims(w2, inter_size // group_size)
+        gs1 = self.group_size if 'w1' in apply_gs else 1
+        w1 = pad_out_dims(w1, inter_size // gs1)
+
+        gs3 = self.group_size if 'w3' in apply_gs else 1
+        w3 = pad_out_dims(w3, inter_size // gs3)
+
+        gs2 = self.group_size if 'w2' in apply_gs else 1
+        w2 = pad_in_dims(w2, inter_size // gs2)
 
         w1, w2, w3 = map(pack_fn, (w1, w2, w3))
         self.model.save_split(w1, fmt.format(idx, 'w1', kind), split_dim=-1, split_num=self.tp, copy=is_lora_a)
@@ -180,54 +172,63 @@ def __init__(self, model: BaseOutputModel):
         self.head_dim = model.model_config.size_per_head
         self.attn_bias = model.model_config.attn_bias
         self.qk_norm = model.model_config.qk_norm
+        self.group_size = max(1, model.model_config.group_size)
 
-    def _reorder_and_merge(self, qkvo, block_size):
+    def _reorder_and_merge(self, qkvo, gs: int):
         q, k, v, o = qkvo
         # reorder output dim for tm's rotary embedding layout
         if self.model.permute_qk:
-            if block_size == 1:
+            if gs == 1:
                 q = permute_v2(q, self.head_dim)
                 k = permute_v2(k, self.head_dim)
             else:
-                assert block_size % self.head_dim == 0
+                assert gs % self.head_dim == 0
         qkv = merge_qkv_v2(q, k, v, self.tp)
         # zero bias for `wo` when `w_qkv` has bias but `wo` doesn't
         if o is None and q.dim() == 1:
             o = torch.zeros_like(q)
         return qkv, o
 
-    def _repeat_kv(self, qkvo, kind: str):
+    def _repeat_kv(self, qkvo, gs: int, kind: str):
         """Replicate kv."""
         q, k, v, o = qkvo
-        head_dim = self.model.model_config.size_per_head
+        head_dim = self.model.model_config.size_per_head // gs
+        kv_head_num = self.model.model_config.kv_head_num // self.model.repeat_kv
        hidden_dim = self.model.model_config.hidden_units
 
         def _repeat(x):
-            dim = hidden_dim if kind != 'bias' else 1
-            x = x.reshape(dim, -1, head_dim)
-            x = x.repeat(1, 1, self.model.repeat_kv)
-            x = x.reshape(dim, -1)
+            n = self.model.repeat_kv
+
+            x = x.reshape(-1, kv_head_num, head_dim)
+            x = x.repeat(1, 1, n)
+            x = x.reshape(-1, kv_head_num * n * head_dim)
+
             return x
 
         k, v = map(_repeat, (k, v))
+
         if kind == 'bias':
             if o is None:
                 o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
             q, k, v, o = map(torch.squeeze, (q, k, v, o))
 
         return (q, k, v, o)
 
-    def _export(self, idx: int, qkvo, kind: str, pack_fn, block_size=1, **kwargs):
+    def _export(self, idx: int, qkvo, kind: str, pack_fn, apply_gs=[], **kwargs):
         if all(x is None for x in qkvo):
             return
         is_lora_a, is_lora_b = get_lora_flags(kind)
-        if is_lora_a:
-            qkv, o = map(transpose, qkvo)
-        else:
-            qkvo = tuple(map(transpose, qkvo))
-            if self.model.repeat_kv:
-                qkvo = self._repeat_kv(qkvo, kind)
-            qkv, o = self._reorder_and_merge(qkvo, block_size)
+        assert not (is_lora_a or is_lora_b)
+
+        qkvo = tuple(map(transpose, qkvo))
+
+        gs = self.group_size if ('w1' in apply_gs) else 1
+
+        if self.model.repeat_kv:
+            qkvo = self._repeat_kv(qkvo, gs, kind)
+
+        qkv, o = self._reorder_and_merge(qkvo, gs)
+
         self.model.save_split(pack_fn(qkv),
                               self._attn.format(idx, 'w_qkv', kind),
                               split_dim=-1,
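
Note on the padding change above: FP8 block scales are group_size times smaller than the weights along the quantized dimension, so when scales take this export path each padding target is divided by the per-tensor group size from apply_gs. A minimal Python sketch of the idea (the pad_* helpers here are simplified stand-ins for the real ones, and all shapes are invented):

import torch

def pad_out_dims(x, dim):
    # zero-pad the last (output) dim up to `dim`
    return torch.nn.functional.pad(x, (0, dim - x.shape[-1]))

def pad_in_dims(x, dim):
    # zero-pad the first (input) dim up to `dim`
    return torch.nn.functional.pad(x, (0, 0, 0, dim - x.shape[0]))

group_size = 128                # FP8 quant block size
inter_size = 11264              # padded intermediate size the engine expects
apply_gs = ['w1', 'w3', 'w2']   # what WeightScaleInv now passes

# one scale per 128x128 block of w1 (after transpose: in_blocks x out_blocks)
w1_scales = torch.rand(4096 // group_size, 11008 // group_size)

gs1 = group_size if 'w1' in apply_gs else 1
w1_scales = pad_out_dims(w1_scales, inter_size // gs1)
print(w1_scales.shape)          # torch.Size([32, 88]) == (..., inter_size // group_size)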

lmdeploy/turbomind/deploy/parameter.py

Lines changed: 3 additions & 3 deletions

@@ -56,16 +56,16 @@ class QuantWeightOnly(Parameter):
 
     def __call__(self, f, g, i):
         f(i, g('qweight'), 'qweight', pack_u4_row)
-        f(i, g('scales'), 'scales', to_half, apply_gs=True)
-        f(i, g('qzeros'), 'zeros', to_half, apply_gs=True)
+        f(i, g('scales'), 'scales', to_half, apply_gs=['w2'])
+        f(i, g('qzeros'), 'zeros', to_half, apply_gs=['w2'])
 
 
 class WeightScaleInv(Parameter):
     KEYS = '.weight_scale_inv', '.weight'
 
     # TODO: flag any operations crossing the quant blocks as illegal
     def __call__(self, f, g, i):
-        f(i, g('weight_scale_inv'), 'scales', to_float, block_size=128)
+        f(i, g('weight_scale_inv'), 'scales', to_float, apply_gs=['w1', 'w3', 'w2'])
         f(i, g('weight'), 'weight', identity)
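
For context, weight_scale_inv in FP8 checkpoints holds one scale per 128x128 weight block, so its shape is the ceil-divided shape of the weight; that is why it can follow the same export path as the weight once apply_gs lists all three projections. A rough numpy model of that relationship (shapes are illustrative, and the real kernels also cast to FP8, which is omitted here):

import numpy as np

def cdiv(a, b):
    return (a + b - 1) // b

out_dim, in_dim, block = 4096, 11000, 128   # in_dim deliberately not a multiple of 128
weight = np.random.randn(out_dim, in_dim).astype(np.float32)
scale_inv = np.random.rand(cdiv(out_dim, block), cdiv(in_dim, block)).astype(np.float32)

# broadcast each per-block scale over its block; edge blocks are ragged
expanded = np.repeat(np.repeat(scale_inv, block, 0), block, 1)[:out_dim, :in_dim]
dequant = weight * expanded
print(scale_inv.shape, dequant.shape)       # (32, 86) (4096, 11000)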

src/turbomind/core/tensor.h

Lines changed: 7 additions & 5 deletions

@@ -22,10 +22,7 @@ class Tensor {
         buffer_ = Buffer(layout_.cosize(), dtype, alloc);
     }
 
-    Tensor(Buffer buffer, Layout layout): layout_{std::move(layout)}, buffer_{std::move(buffer)}
-    {
-        TM_CHECK_LE(layout_.cosize(), buffer_.size());
-    }
+    Tensor(Buffer buffer, Layout layout): layout_{std::move(layout)}, buffer_{buffer.slice(0, layout_.cosize())} {}
 
     Tensor(Buffer buffer): layout_{buffer.size()}, buffer_{buffer} {}
@@ -204,11 +201,16 @@ class Tensor {
     Buffer buffer_;
 };
 
-static Tensor empty_like(const Tensor& tensor, std::optional<Device> device = {})
+inline Tensor empty_like(const Tensor& tensor, std::optional<Device> device = {})
 {
     return Tensor{tensor.layout(), tensor.dtype(), device ? *device : tensor.device()};
 }
 
+inline Tensor empty_like(const Tensor& tensor, DataType dtype)
+{
+    return Tensor{tensor.layout(), dtype, tensor.device()};
+}
+
 void Copy(const Tensor& src, Ref<Tensor> dst_, const Stream& stream);
 
 void Copy(const Tensor& src, Ref<Tensor> dst_);

src/turbomind/kernels/gemm/test/testbed_v2.h

Lines changed: 18 additions & 4 deletions

@@ -141,8 +141,22 @@
         rng_.NormalFloat(b_, 1., 1.);
 
         if (Ta == kFloat8_e4m3) {
-            QuantizeSymmBlock(a_q_, a_s_, a_, stream);
-            DequantizeSymmBlock(a_f_, a_q_, a_s_, stream);
+            if (expert_num_ == 0) {
+                QuantizeSymmBlock(a_q_, a_s_, a_, stream);
+                DequantizeSymmBlock(a_f_, a_q_, a_s_, stream);
+            }
+            else {
+                a_q_ = empty_like(a_, kFloat8_e4m3);
+                a_f_ = empty_like(a_);
+                const int m_s = cdiv(M, 128);
+                a_s_ = Tensor_<float>({m_s * expert_num_, cdiv(K, 128)}, kDEVICE);
+                for (int i = 0; i < expert_num_; ++i) {
+                    auto a_s = a_s_.slice(i * m_s, m_s);
+                    QuantizeSymmBlock(a_q_.slice(i * M, M), a_s, a_.slice(i * M, M), stream);
+                    DequantizeSymmBlock(a_f_.slice(i * M, M), a_q_.slice(i * M, M), a_s, stream);
+                }
+            }
+
             a_q_desc_ = {a_q_.dtype(), kRowMajor, M, K, (int)a_q_.stride(0)};
             u_desc_ = {a_s_.dtype(), kRowMajor, (int)a_s_.shape(0), (int)a_s_.shape(1), (int)a_s_.stride(0)};
             tie(a_x_, a_desc_x_) = std::make_tuple(&a_q_, &a_q_desc_);
@@ -221,8 +235,8 @@
         for (int i = 0; i < expert_num_; ++i) {
             a_ptrs[i] = reinterpret_cast<uint64_t>(a_x_->slice(m_offset[i]).raw_data());
             if (a_s_) {
-                TM_CHECK(m_offset[i] % 128 == 0);
-                u_ptrs[i] = reinterpret_cast<uint64_t>(a_s_.slice(m_offset[i] / 128).raw_data());
+                const int m_s = cdiv(M, 128);
+                u_ptrs[i] = reinterpret_cast<uint64_t>(a_s_.slice(i * m_s).raw_data());
             }
         }
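
The test change above drops the old requirement that every expert's token offset be a multiple of 128: scales are now laid out per expert, m_s = cdiv(M, 128) rows each, so expert i starts at scale row i * m_s no matter how ragged M is. A tiny Python sketch of the indexing, with assumed sizes:

def cdiv(a, b):
    return (a + b - 1) // b

M, expert_num = 200, 4      # 200 tokens per expert; not a multiple of 128
m_s = cdiv(M, 128)          # scale rows per expert (here 2)

for i in range(expert_num):
    print(f'expert {i}: token rows [{i * M}, {(i + 1) * M}), '
          f'scale rows [{i * m_s}, {(i + 1) * m_s})')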

src/turbomind/kernels/norm/rms_norm.cu

Lines changed: 4 additions & 4 deletions

@@ -86,15 +86,15 @@ __global__ void RMSNorm(T* dst,
 
 void invokeRMSNorm(Tensor& out, const Tensor& x, const Tensor& w, float eps, cudaStream_t st)
 {
+    if (x.size() == 0) {
+        return;
+    }
+
     TM_CHECK(x.ndim() == 2);
     TM_CHECK(out.shape() == x.shape());
     TM_CHECK(out.dtype() == x.dtype());
     TM_CHECK(w.dtype() == x.dtype() && w.shape(-1) == x.shape(-1));
 
-    if (x.size() == 0) {
-        return;
-    }
-
     auto invoke = [&](auto t) {
         using T = decltype(t);

src/turbomind/kernels/quantization.cu

Lines changed: 30 additions & 17 deletions

@@ -22,14 +22,17 @@ __global__ void quant_symm_row(
 #if TURBOMIND_ARCH_SM90
     static_assert(group_size % vec_size == 0);
     constexpr int threads = group_size / vec_size;
+    const int dim1 = round_up(dim, WARP_SIZE * vec_size);
     for (int ti = blockIdx.x; ti < num; ti += gridDim.x) {
-        for (int di = threadIdx.x * vec_size; di < dim; di += blockDim.x * vec_size) {
-            Array<T, vec_size> vec;
-            Ldg(vec, src + ti * src_ld + di);
+        for (int di = threadIdx.x * vec_size; di < dim1; di += blockDim.x * vec_size) {
+            Array<T, vec_size> vec{};
+            if (di < dim) {
+                Ldg(vec, src + ti * src_ld + di);
+            }
             auto absmax = fmaxf(static_cast<Tscale>(find_absmax<threads>(vec)), 1e-8f);
             const Tscale scale = absmax / qmax;
             const Tscale inv_scale = qmax / absmax;
-            if (threadIdx.x % threads == 0) {
+            if (threadIdx.x % threads == 0 && di < dim) {
                 // column-major
                 scales[(di / group_size) * scales_ld + ti] = scale;
             }
@@ -38,7 +41,9 @@ __global__ void quant_symm_row(
             for (int c = 0; c < vec_size; ++c) {
                 tmp[c] = Tout(static_cast<Tscale>(vec[c]) * inv_scale);
             }
-            Store(out + ti * out_ld + di, tmp);
+            if (di < dim) {
+                Store(out + ti * out_ld + di, tmp);
+            }
         }
     }
 #endif
@@ -69,11 +74,13 @@ void QuantizeSymm(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st
 
     const int aligned_num = round_up<int>(num, alignment);
 
+    const int s_dim = cdiv<ssize_t>(dim, group_size);
+
     if (!scale) {
-        scale = Tensor_<Tscale>({{dim / group_size, num}, {aligned_num, 1}}, kDEVICE);
+        scale = Tensor_<Tscale>({{s_dim, num}, {aligned_num, 1}}, kDEVICE);
     }
     else {
-        TM_CHECK(std::make_tuple(dim / group_size, num) == scale.shapes(0, 1));
+        TM_CHECK(std::make_tuple(s_dim, num) == scale.shapes(0, 1));
         TM_CHECK(scale.stride(1) == 1);
         TM_CHECK(scale.stride(0) % alignment == 0);
     }
@@ -159,17 +166,17 @@ __global__ void quant_symm_block(Tout* out, Tscale* scales, const T* src, Tscale
     __shared__ typename BlockReduce::TempStorage temp_storage;
     __shared__ T shared_inv_scale;
 
-    const int ti = blockIdx.x * block_size;
-    const int di = blockIdx.y * block_size;
-    const int col = threadIdx.x % threads;
     const int row = threadIdx.x / threads;
+    const int col = threadIdx.x % threads;
+    const int ti = blockIdx.x * block_size;
+    const int di = blockIdx.y * block_size + col * vec_size;
 
     T absmax{};
     Array<T, vec_size> xs[S]{};
     PRAGMA_UNROLL
     for (int s = 0; s < S; ++s) {
-        if (auto r = ti + s * rows + row; r < num) {
-            Ldg(xs[s], src + (int64_t)r * dim + di + col * vec_size);
+        if (auto r = ti + s * rows + row; r < num && di < dim) {
+            Ldg(xs[s], src + (int64_t)r * dim + di);
         }
         PRAGMA_UNROLL
         for (int i = 0; i < vec_size; ++i) {
@@ -193,14 +200,14 @@ __global__ void quant_symm_block(Tout* out, Tscale* scales, const T* src, Tscale
         for (int i = 0; i < vec_size; ++i) {
             ys[s][i] = Tout(static_cast<Tscale>(xs[s][i]) * inv_scale);
         }
-        if (auto r = ti + s * rows + row; r < num) {
-            Store(out + (int64_t)r * dim + di + col * vec_size, ys[s]);
+        if (auto r = ti + s * rows + row; r < num && di < dim) {
+            Store(out + (int64_t)r * dim + di, ys[s]);
         }
     }
 #endif
 }
 
-void QuantizeSymmBlock(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st)
+void QuantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> scale_, const Tensor& src, cudaStream_t st)
 {
     TM_CHECK(src.is_contiguous());
     TM_CHECK_EQ(src.ndim(), 2);
@@ -220,6 +227,9 @@ void QuantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> scale_, const Tensor& src, cudaStream_t st)
     constexpr int cta_size = 1024;
     const dim3 grid(bnum, bdim);
 
+    auto& out = out_.get();
+    auto& scale = scale_.get();
+
     if (!out) {
         out = Tensor_<Tout>{src.layout(), kDEVICE};
     }
@@ -259,7 +269,7 @@ __global__ void dequant_symm_block(Tout* out, const T* src, const Tscale* scales
     PRAGMA_UNROLL
     for (int s = 0; s < S; ++s) {
         const auto ti = blockIdx.x * block_size + s * rows + row;
-        if (ti < num) {
+        if (ti < num && di < dim) {
             Array<T, vec_size> x;
             Ldg(x, src + (int64_t)ti * dim + di);
             Array<Tout, vec_size> y;
@@ -273,7 +283,7 @@ __global__ void dequant_symm_block(Tout* out, const T* src, const Tscale* scales
 #endif
 }
 
-void DequantizeSymmBlock(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st)
+void DequantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> src_, const Tensor& scale, cudaStream_t st)
 {
     using T = fp8_e4m3_t;
     using Tout = bfloat16_t;
@@ -282,6 +292,9 @@ void DequantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> src_, const Tensor& scale, cudaStream_t st)
     constexpr int block_size = 128;
     constexpr int vec_size = 8;
 
+    auto& out = out_.get();
+    auto& src = src_.get();
+
     if (!out) {
         out = Tensor_<Tout>{src.layout(), kDEVICE};
     }

src/turbomind/kernels/quantization.h

Lines changed: 2 additions & 2 deletions

@@ -6,8 +6,8 @@ void QuantizeSymm(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st
 
 void DequantizeSymm(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st);
 
-void QuantizeSymmBlock(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st);
+void QuantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> scale_, const Tensor& src, cudaStream_t st);
 
-void DequantizeSymmBlock(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st);
+void DequantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> src_, const Tensor& scale, cudaStream_t st);
 
 }  // namespace turbomind

src/turbomind/models/llama/moe_ffn_layer.cc

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ MoeFfnLayer::MoeFfnLayer(const ModelParam& model, const MoeParam& param, const E
 
     h_offsets_ = {max_expert_num + 1, kCPUpinned};
 
-    const int max_token_num = engine.max_forward_token_num;
+    const int max_token_num = engine.max_forward_token_num * engine.attn_dp_size;
     const int pad_token_num = (max_token_num + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;
 
     masks_ = {max_expert_num * pad_token_num, kDEVICE};
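
With attention data parallelism, one MoE layer may receive the tokens of all attention DP ranks in a single forward pass, hence the * engine.attn_dp_size when sizing the gating buffers; the result is then rounded up to the gate kernel's vector width. The arithmetic, with illustrative numbers (kMoeGateVecSize = 4 is an assumption, not the real constant):

kMoeGateVecSize = 4            # assumed; see the moe gate kernel for the real value
max_forward_token_num = 1001   # illustrative
attn_dp_size = 2

max_token_num = max_forward_token_num * attn_dp_size
pad_token_num = (max_token_num + kMoeGateVecSize - 1) // kMoeGateVecSize * kMoeGateVecSize
print(max_token_num, pad_token_num)  # 2002 2004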
