@@ -609,87 +609,72 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }
 
-uint16_t f8_e4m3_to_f16(uint8_t f8) {
-    // do we need to support uz?
-
-    const uint32_t exponent_bias = 7;
-    if (f8 == 0xff) {
-        return ggml_fp32_to_fp16(-NAN);
-    } else if (f8 == 0x7f) {
-        return ggml_fp32_to_fp16(NAN);
+uint16_t f8_e3m4_to_f16(uint8_t fp8) {
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint32_t exponent_bias = 0x3;  // 2^(3-1)-1
+    const uint32_t f16_bias      = 0xF;  // 2^(5-1)-1
+    const int mantissa_bits      = 4;
+    const int mantissa_max       = 0xF;  // 2^4-1
 
-    uint32_t sign     = f8 & 0x80;
-    uint32_t exponent = (f8 & 0x78) >> 3;
-    uint32_t mantissa = f8 & 0x07;
-    uint32_t result   = sign << 24;
-    if (exponent == 0) {
-        if (mantissa > 0) {
-            exponent = 0x7f - exponent_bias;
-
-            // yes, 2 times
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
+    uint8_t sign     = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
 
-            result |= (mantissa & 0x03) << 21;
-            result |= exponent << 23;
+    uint16_t fp16_sign     = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
         }
-    } else {
-        result |= mantissa << 20;
-        exponent += 0x7f - exponent_bias;
-        result |= exponent << 23;
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 6;
 
-    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
 }
 
-uint16_t f8_e5m2_to_f16(uint8_t fp8) {
-    uint8_t sign     = (fp8 >> 7) & 0x1;
-    uint8_t exponent = (fp8 >> 2) & 0x1F;
-    uint8_t mantissa = fp8 & 0x3;
-
-    uint16_t fp16_sign = sign << 15;
-    uint16_t fp16_exponent;
-    uint16_t fp16_mantissa;
-
-    if (exponent == 0 && mantissa == 0) {  // zero
-        return fp16_sign;
+uint16_t f8_e4m3_to_f16(uint8_t fp8) {
+    // do we need to support uz?
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint32_t exponent_bias = 0x7;  // 2^(4-1)-1
+    const uint32_t f16_bias      = 0xF;  // 2^(5-1)-1
+    const int mantissa_bits      = 3;
+    const int mantissa_max       = 0x7;  // 2^3-1
 
-    if (exponent == 0x1F) {  // NAN and INF
-        fp16_exponent = 0x1F;
-        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
-        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
-    }
+    uint8_t sign     = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
 
-    if (exponent == 0) {  // subnormal numbers
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-        return fp16_sign | fp16_mantissa;
+    uint16_t fp16_sign     = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
+        }
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 7;
 
-    // normal numbers
-    int16_t true_exponent = (int16_t)exponent - 15 + 15;
-    if (true_exponent <= 0) {
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-    } else if (true_exponent >= 0x1F) {
-        fp16_exponent = 0x1F;
-        fp16_mantissa = 0;
-    } else {
-        fp16_exponent = (uint16_t)true_exponent;
-        fp16_mantissa = mantissa << 8;
-    }
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
+}
 
-    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+uint16_t f8_e5m2_to_f16(uint8_t fp8) {
+    // do we need to support fnuz?
+    return static_cast<uint16_t>(fp8) << 8;
 }
 
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
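
A quick sanity check on the three converters, as a minimal sketch (not part of the commit; assumes the functions above are visible in the same translation unit): the byte 0x40 happens to encode +2.0 in all three FP8 layouts, and +2.0 in IEEE half precision is 0x4000. E5M2 shares F16's 5-bit exponent, which is why its conversion reduces to a plain left shift.

    #include <cassert>
    #include <cstdint>

    // 0x40 encodes +2.0 in all three layouts:
    //   E3M4: exponent 4,  bias 3  -> 2^1
    //   E4M3: exponent 8,  bias 7  -> 2^1
    //   E5M2: exponent 16, bias 15 -> 2^1
    static void fp8_smoke_test() {
        assert(f8_e3m4_to_f16(0x40) == 0x4000);  // 0x4000 is 2.0 in f16
        assert(f8_e4m3_to_f16(0x40) == 0x4000);
        assert(f8_e5m2_to_f16(0x40) == 0x4000);
    }
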
@@ -699,6 +684,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     }
 }
 
+void f8_e3m4_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e3m4_to_f16(src[i]);
+    }
+}
+
 void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
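
The "support inplace op" comment carries real weight here: src and dst may alias, and each one-byte f8 element widens into a two-byte f16 at offset 2*i of the same buffer, so the loop runs backwards to keep writes from landing on input that has not been read yet. A worked illustration (not from the commit), with n = 4:

    // i = 3: read byte 3, write bytes 6-7   (writes never touch unread input)
    // i = 2: read byte 2, write bytes 4-5
    // i = 1: read byte 1, write bytes 2-3
    // i = 0: read byte 0, write bytes 0-1
    // A forward loop would clobber byte 1 (still unread) while writing dst[0].
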
@@ -946,6 +938,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E3M4") {
+        ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E5M2") {
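
Since ggml has no FP8 tensor type, every F8_* dtype is reported as GGML_TYPE_F16 and the payload is widened while loading; downstream graph code never sees an f8 tensor. A usage sketch (assumed caller, not from the commit):

    ggml_type t = str_to_ggml_type("F8_E3M4");
    // t == GGML_TYPE_F16: the f8 payload is converted to f16 during load,
    // so no ggml FP8 type is needed.
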
@@ -1059,6 +1053,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         if (dtype == "BF16") {
             tensor_storage.is_bf16 = true;
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E3M4") {
+            tensor_storage.is_f8_e3m4 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else if (dtype == "F8_E4M3") {
             tensor_storage.is_f8_e4m3 = true;
             // f8 -> f16
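
The factor of 2 in the assertion reflects the same widening: tensor_data_size counts the one-byte f8 elements as stored in the file, while nbytes() already describes the two-byte f16 layout the tensor will have after conversion. For instance (an assumed shape, for illustration only):

    // A [4096] F8_E3M4 tensor:
    //   in the file:  tensor_data_size = 4096 bytes (1 byte per element)
    //   after load:   nbytes() = 4096 * sizeof(uint16_t) = 8192 bytes
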
@@ -1461,10 +1459,10 @@ SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
     bool input_block_checked = false;
 
-    bool has_multiple_encoders    = false;
-    bool is_unet = false;
+    bool has_multiple_encoders = false;
+    bool is_unet               = false;
 
-    bool is_xl = false;
+    bool is_xl   = false;
     bool is_flux = false;
 
 #define found_family (is_xl || is_flux)
@@ -1481,7 +1479,7 @@ SDVersion ModelLoader::get_sd_version() {
             }
             if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
                 is_unet = true;
-                if(has_multiple_encoders){
+                if (has_multiple_encoders) {
                     is_xl = true;
                     if (input_block_checked) {
                         break;
@@ -1490,7 +1488,7 @@ SDVersion ModelLoader::get_sd_version() {
             }
             if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
                 has_multiple_encoders = true;
-                if(is_unet){
+                if (is_unet) {
                     is_xl = true;
                     if (input_block_checked) {
                         break;
@@ -1779,6 +1777,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                     if (tensor_storage.is_bf16) {
                         // inplace op
                         bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                    } else if (tensor_storage.is_f8_e3m4) {
+                        // inplace op
+                        f8_e3m4_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                     } else if (tensor_storage.is_f8_e4m3) {
                         // inplace op
                         f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
@@ -1793,6 +1794,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                     if (tensor_storage.is_bf16) {
                         // inplace op
                         bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                    } else if (tensor_storage.is_f8_e3m4) {
+                        // inplace op
+                        f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                     } else if (tensor_storage.is_f8_e4m3) {
                         // inplace op
                         f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
@@ -1811,6 +1815,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
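
All three load paths follow the same pattern: the destination buffer is allocated at the final f16 size, the raw f8 payload is read into its front half, and the _vec helper expands it in place. A condensed sketch (assumptions: plain fread stands in for the loader's actual buffered reads, n for nelements(), and f8_e3m4_to_f16_vec is the function added in this diff):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    void load_f8_e3m4_payload(std::FILE* file, int64_t n) {
        std::vector<uint8_t> read_buffer(n * sizeof(uint16_t));  // final f16 size
        std::fread(read_buffer.data(), 1, n, file);              // n raw f8 bytes up front
        f8_e3m4_to_f16_vec(read_buffer.data(),
                           reinterpret_cast<uint16_t*>(read_buffer.data()), n);
    }
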