@@ -630,6 +630,11 @@ def extra_repr(self):
         return f"inner_k_tiles={self.inner_k_tiles}"


+@dataclass(frozen=True)
+class Int4CPULayout(Layout):
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        return input
+
 @dataclass(frozen=True)
 class Float8Layout(Layout):
     mm_config: Optional[Float8MMConfig] = None
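
For orientation (not part of the diff itself): `Int4CPULayout` carries no packing parameters and its `pre_process` is a passthrough; the actual bit-packing happens in the tensor impl added in the next hunk, whose docstring describes the storage as two int4 values per uint8 byte, i.e. an `[n][k]` int32 tensor becomes an `[n][k / 2]` uint8 tensor. Below is a minimal sketch of that shape bookkeeping only, assuming a simple low/high-nibble packing for illustration; the real packing order is internal to `torch.ops.aten._convert_weight_to_int4pack_for_cpu` and may differ.

```python
# Illustration only: the [n][k] int32 -> [n][k // 2] uint8 bookkeeping implied by
# Int4CPULayout, using an assumed low/high-nibble ordering (not the kernel's actual one).
import torch

def pack_int4_rowwise(w: torch.Tensor) -> torch.Tensor:
    # w: [n, k] int32 with values in [0, 15]; k must be even
    assert w.dtype == torch.int32 and w.shape[1] % 2 == 0
    lo = w[:, 0::2].to(torch.uint8)  # even columns -> low nibble
    hi = w[:, 1::2].to(torch.uint8)  # odd columns  -> high nibble
    return lo | (hi << 4)            # [n, k // 2] uint8

def unpack_int4_rowwise(packed: torch.Tensor) -> torch.Tensor:
    lo = (packed & 0x0F).to(torch.int32)
    hi = (packed >> 4).to(torch.int32)
    # interleave low/high nibbles back into [n, k] int32
    return torch.stack([lo, hi], dim=-1).flatten(start_dim=1)

w = torch.randint(0, 16, (4, 8), dtype=torch.int32)
assert torch.equal(unpack_int4_rowwise(pack_int4_rowwise(w)), w)
```
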
@@ -1616,6 +1621,230 @@ def get_layout(self) -> Layout:
         return self._layout


+@register_layout(Int4CPULayout)
+class Int4CPUAQTTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl for the int4 CPU layout of an affine quantized tensor. This is for int4 only,
+    used by the tinygemm CPU kernel `_weight_int4pack_mm_for_cpu`.
+
+    It stores the original [n][k] tensor (int32 dtype) as a packed 2-d weight of shape
+    [n][k / 2] (uint8 dtype), i.e. two int4 values per byte.
+
+    Note: we also pack the scale and zero point together here for the tinygemm kernel.
+
+    Note: technically the int4 CPU layout should only describe the underlying packed weight
+    (int tensor), but since the scale and zero_point are packed into the same tensor here,
+    which is not done in the plain layout, we created a layout for the whole AQT for now.
+    This could be improved if we split the int4 AQT into a separate tensor subclass.
+
+    fields:
+      packed_weight (torch.Tensor): the packed 2-d tensor in the int4 CPU layout
+      scale_and_zero (torch.Tensor): the combined scale and zero_point tensor used to map
+        between the floating point tensor and the quantized tensor
+    """
+
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout")
+            if kwargs.get("layout", False)
+            else packed_weight.layout
+        )
+        kwargs["dtype"] = packed_weight.dtype
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        self.packed_weight = packed_weight
+        self.scale_and_zero = scale_and_zero
+        self.transposed = False
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale_and_zero = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict["scale_and_zero"],
+        )
+        (
+            transposed,
+            _layout,
+        ) = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, _layout)
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, Int4CPULayout)
+
+        assert (
+            int_data.dtype == torch.int32
+        ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+            int_data, 1  # TODO: remove
+        )
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+
+        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+        return cls(packed_weight, scale_and_zero, False, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
+            self.transposed,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        # self.packed_weight = fn(self.packed_weight)
+        # self.scale_and_zero = fn(self.scale_and_zero)
+        # return self
+        return self.__class__(
+            fn(self.packed_weight),
+            fn(self.scale_and_zero),
+            self.transposed,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            """we don't need to repack the weight and just rely on external
+            shape being changed and record the status of transpose/no-transpose
+            """
+            transposed = Int4CPUAQTTensorImpl(
+                args[0].packed_weight,
+                args[0].scale_and_zero,
+                not args[0].transposed,
+                args[0]._layout,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, transposed)
+
+        if func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                int_data, scale, zero_point = self.get_plain()
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            elif dim == 1:
+                int_data, scale, zero_point = self.get_plain()
+                assert step == 1, "Only step == 1 is supported in slicing right now"
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return sliced
+            else:
+                raise NotImplementedError(
+                    f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        from torchao.quantization.quant_primitives import (
+            ZeroPointDomain,
+            quantize_affine,
+        )
+        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+
+        cur_shape = self.shape
+        assert len(cur_shape) == 2
+        original_shape = (cur_shape[0], cur_shape[1] * 2)
+        eye_shape = original_shape[1]
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        block_size = (1, groupsize)
+        device = self.device
+        original_dtype = torch.bfloat16
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        zero_point_domain = ZeroPointDomain.FLOAT
+        assert len(block_size) == 2 and block_size[0] == 1
+        dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            torch.eye(eye_shape, device=device, dtype=original_dtype),
+            self.packed_weight,
+            groupsize,
+            self.scale_and_zero,
+        )
+        dequantized = dequantized.t().contiguous()
+        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+        scale = scale.reshape(scale.shape[:-1]).contiguous()
+        zero = zero.reshape(zero.shape[:-1]).contiguous()
+        int_data = quantize_affine(
+            dequantized,
+            block_size,
+            scale,
+            zero,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        return int_data, scale, zero
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
 #####################################################
 # torch functional and aten operator implementation #
 #####################################################
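
Finally, a hedged usage sketch (not part of this commit) of how the new layout might be selected when quantizing a model for CPU. It assumes `Int4CPULayout` is exported from `torchao.dtypes` and that `int4_weight_only()` accepts a `layout` argument, as in recent torchao releases; names may differ in older versions.

```python
# Sketch: selecting the new CPU layout for int4 weight-only quantization.
# Assumes recent torchao APIs (quantize_, int4_weight_only(layout=...)).
import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import Int4CPULayout

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16)
quantize_(model, int4_weight_only(group_size=128, layout=Int4CPULayout()))

# The Linear weight should now be an AffineQuantizedTensor whose tensor_impl is
# Int4CPUAQTTensorImpl, so F.linear on bfloat16 inputs dispatches to
# torch.ops.aten._weight_int4pack_mm_for_cpu via the registered CPU path.
x = torch.randn(1, 1024, dtype=torch.bfloat16)
y = model(x)
```
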