Skip to content

Commit 9df016f

Browse files
authored
[Inference] Fix quant bits order (hpcaitech#5681)
1 parent f799631 commit 9df016f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

extensions/csrc/funcs/cast_functor.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
390390
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
391391
uint16_t tmp2 =
392392
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
393-
uint16_t res = (tmp1 << 8U) | tmp2;
393+
uint16_t res = (tmp2 << 8U) | tmp1;
394394
return res;
395395
}))
396396

@@ -401,8 +401,8 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
401401
b = CastFunctor<float, uint8_t>()(val.y);
402402
c = CastFunctor<float, uint8_t>()(val.z);
403403
d = CastFunctor<float, uint8_t>()(val.w);
404-
return (a << 24U) | (b << 16U) |
405-
(c << 8U) | d;
404+
return (d << 24U) | (c << 16U) |
405+
(b << 8U) | a;
406406
}))
407407

408408
// fp8x4 -> float4_
@@ -458,7 +458,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
458458
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
459459
uint16_t b =
460460
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
461-
return (a << 8U) | b;
461+
return (b << 8U) | a;
462462
}))
463463

464464
// bf164 -> fp8x4

0 commit comments

Comments
 (0)