Skip to content

Commit 86733e3

Browse files
committed
add comments
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent 750316b commit 86733e3

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

csrc/quantization/fp4/nvfp4_experts_quant.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,8 @@ cvt_fp16_to_fp4(
281281
}
282282
}
283283
} else {
284+
// Load input offsets into registers first, then do the computation.
285+
// The local array is capped at 17 entries because of the register limit; offsets are processed in chunks of 16 (see the loop below).
284286
uint32_t local_offsets[17];
285287
for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
286288
*reinterpret_cast<int4*>(local_offsets) =
@@ -350,6 +352,9 @@ cvt_fp16_to_fp4(
350352
"Vec size is not matched.");
351353
extern __shared__ uint32_t shared_input_offsets[];
352354

355+
// Load input offsets into shared memory.
356+
// If n_experts is larger than 4, use vectorized int4 loads to save instructions.
357+
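// If n_experts is smaller than 4 (the SMALL_NUM_EXPERTS path), read the offsets one element at a time.
// NOTE(review): n_experts == 4 is not covered by either sentence; confirm which path it takes.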
353358
if constexpr (SMALL_NUM_EXPERTS) {
354359
for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
355360
shared_input_offsets[i] = input_offset_by_experts[i];

0 commit comments

Comments
 (0)