@@ -574,6 +574,87 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
574
574
return (d * q ).reshape ((n_blocks , QK_K ))
575
575
576
576
577
+ class TQ1_0 (__Quant , qtype = GGMLQuantizationType .TQ1_0 ):
578
+ @classmethod
579
+ def quantize_blocks (cls , blocks : np .ndarray ) -> np .ndarray :
580
+ n_blocks = blocks .shape [0 ]
581
+
582
+ d = abs (blocks ).max (axis = - 1 , keepdims = True )
583
+ with np .errstate (divide = "ignore" ):
584
+ id = np .where (d == 0 , 0 , 1 / d )
585
+ qs = np_roundf (blocks * id )
586
+ qs = (qs .astype (np .int8 ) + np .int8 (1 )).astype (np .uint8 )
587
+
588
+ qs0 , qs1 , qh = qs [..., :(32 * 5 )], qs [..., (32 * 5 ):(48 * 5 )], qs [..., (48 * 5 ):]
589
+ qs0 = qs0 .reshape ((n_blocks , - 1 , 5 , 32 )) * np .array ([81 , 27 , 9 , 3 , 1 ], dtype = np .uint8 ).reshape ((1 , 1 , 5 , 1 ))
590
+ qs0 = np .sum (qs0 , axis = - 2 ).reshape ((n_blocks , - 1 ))
591
+ qs1 = qs1 .reshape ((n_blocks , - 1 , 5 , 16 )) * np .array ([81 , 27 , 9 , 3 , 1 ], dtype = np .uint8 ).reshape ((1 , 1 , 5 , 1 ))
592
+ qs1 = np .sum (qs1 , axis = - 2 ).reshape ((n_blocks , - 1 ))
593
+ qh = qh .reshape ((n_blocks , - 1 , 4 , 4 )) * np .array ([81 , 27 , 9 , 3 ], dtype = np .uint8 ).reshape ((1 , 1 , 4 , 1 ))
594
+ qh = np .sum (qh , axis = - 2 ).reshape ((n_blocks , - 1 ))
595
+ qs = np .concatenate ([qs0 , qs1 , qh ], axis = - 1 )
596
+ qs = (qs .astype (np .uint16 ) * 256 + (243 - 1 )) // 243
597
+
598
+ qs = qs .astype (np .uint8 )
599
+ d = d .astype (np .float16 ).view (np .uint8 )
600
+
601
+ return np .concatenate ([qs , d ], axis = - 1 )
602
+
603
+ @classmethod
604
+ def dequantize_blocks (cls , blocks : np .ndarray ) -> np .ndarray :
605
+ n_blocks = blocks .shape [0 ]
606
+
607
+ qs , rest = np .hsplit (blocks , [(QK_K - 4 * QK_K // 64 ) // 5 ])
608
+ qh , d = np .hsplit (rest , [QK_K // 64 ])
609
+
610
+ d = d .view (np .float16 ).astype (np .float32 )
611
+
612
+ qs0 , qs1 = qs [..., :32 ], qs [..., 32 :]
613
+ qs0 = qs0 .reshape ((n_blocks , - 1 , 1 , 32 )) * np .array ([1 , 3 , 9 , 27 , 81 ], dtype = np .uint8 ).reshape ((1 , 1 , 5 , 1 ))
614
+ qs0 = qs0 .reshape ((n_blocks , - 1 ))
615
+ qs1 = qs1 .reshape ((n_blocks , - 1 , 1 , 16 )) * np .array ([1 , 3 , 9 , 27 , 81 ], dtype = np .uint8 ).reshape ((1 , 1 , 5 , 1 ))
616
+ qs1 = qs1 .reshape ((n_blocks , - 1 ))
617
+ qh = qh .reshape ((n_blocks , - 1 , 1 , 4 )) * np .array ([1 , 3 , 9 , 27 ], dtype = np .uint8 ).reshape ((1 , 1 , 4 , 1 ))
618
+ qh = qh .reshape ((n_blocks , - 1 ))
619
+ qs = np .concatenate ([qs0 , qs1 , qh ], axis = - 1 )
620
+ qs = ((qs .astype (np .uint16 ) * 3 ) >> 8 ).astype (np .int8 ) - np .int8 (1 )
621
+
622
+ return (d * qs .astype (np .float32 ))
623
+
624
+
625
+ class TQ2_0 (__Quant , qtype = GGMLQuantizationType .TQ2_0 ):
626
+ @classmethod
627
+ def quantize_blocks (cls , blocks : np .ndarray ) -> np .ndarray :
628
+ n_blocks = blocks .shape [0 ]
629
+
630
+ d = abs (blocks ).max (axis = - 1 , keepdims = True )
631
+ with np .errstate (divide = "ignore" ):
632
+ id = np .where (d == 0 , 0 , 1 / d )
633
+ qs = np_roundf (blocks * id )
634
+ qs = (qs .astype (np .int8 ) + np .int8 (1 )).astype (np .uint8 )
635
+
636
+ qs = qs .reshape ((n_blocks , - 1 , 4 , 32 )) << np .array ([0 , 2 , 4 , 6 ], dtype = np .uint8 ).reshape ((1 , 1 , 4 , 1 ))
637
+ qs = qs [..., 0 , :] | qs [..., 1 , :] | qs [..., 2 , :] | qs [..., 3 , :]
638
+ qs = qs .reshape ((n_blocks , - 1 ))
639
+
640
+ d = d .astype (np .float16 ).view (np .uint8 )
641
+
642
+ return np .concatenate ([qs , d ], axis = - 1 )
643
+
644
+ @classmethod
645
+ def dequantize_blocks (cls , blocks : np .ndarray ) -> np .ndarray :
646
+ n_blocks = blocks .shape [0 ]
647
+
648
+ qs , d = np .hsplit (blocks , [QK_K // 4 ])
649
+
650
+ d = d .view (np .float16 ).astype (np .float32 )
651
+
652
+ qs = qs .reshape ((n_blocks , - 1 , 1 , 32 )) >> np .array ([0 , 2 , 4 , 6 ], dtype = np .uint8 ).reshape ((1 , 1 , 4 , 1 ))
653
+ qs = (qs & 0x03 ).reshape ((n_blocks , - 1 )).astype (np .int8 ) - np .int8 (1 )
654
+
655
+ return (d * qs .astype (np .float32 ))
656
+
657
+
577
658
class IQ2_XXS (__Quant , qtype = GGMLQuantizationType .IQ2_XXS ):
578
659
ksigns : bytes = (
579
660
b"\x00 \x81 \x82 \x03 \x84 \x05 \x06 \x87 \x88 \x09 \x0a \x8b \x0c \x8d \x8e \x0f "
0 commit comments