import torch

+ from torchao.ops import lib
from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+ from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

try:
-     from fast_hadamard_transform import hadamard_transform
+     from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform

def matmul_hadU(X, hadK, K):
    if X.is_cuda:
@@ -32,16 +34,59 @@ def matmul_hadU(X, hadK, K):
        return matmul_hadU_slow(X, hadK, K)


+ def register_custom_op_impl(name):
+     def decorator(func):
+         if TORCH_VERSION_AT_LEAST_2_4:
+             return torch.library.custom_op(f"{name}", mutates_args=())(func)
+         else:
+             lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor")
+             return torch.library.impl(f"{name}", "cuda")(func)
+     return decorator
+
+
+ def register_custom_op_abstract(name):
+     def decorator(func):
+         if TORCH_VERSION_AT_LEAST_2_4:
+             return torch.library.register_fake(f"{name}")(func)
+         else:
+             return torch.library.impl_abstract(f"{name}")(func)
+     return decorator
+
+
+ @register_custom_op_impl("torchao::hadamard_transform")
+ def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+     """
+     Arguments:
+         x: (..., dim)
+         scale: float. Multiply the output by this number.
+     Returns:
+         out: (..., dim)
+
+     Multiply each row of x by the Hadamard transform matrix.
+     Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
+     If dim is not a power of 2, we implicitly pad x with zero so that dim is the next power of 2.
+
+     Source: https://github.com/Dao-AILab/fast-hadamard-transform
+     """
+     return _fast_hadamard_transform(x, scale)
+
+
+ @register_custom_op_abstract("torchao::hadamard_transform")
+ def _(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+     torch._check(x.dim() >= 1, lambda: f"input should be at least a 1D tensor, got {x.dim()}D")
+     return torch.empty_like(x)
+
+
class HadamardTransform(torch.autograd.Function):
    """The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))"""

    @staticmethod
    def forward(ctx, u):
-         return hadamard_transform(u)
+         return _fast_hadamard_transform(u)

    @staticmethod
    def backward(ctx, grad):
-         return hadamard_transform(grad)
+         return _fast_hadamard_transform(grad)


def is_pow2(n):
@@ -144,9 +189,9 @@ def matmul_hadU_slow(X, hadK, K):
def matmul_hadU_fast(X, hadK, K):
    n = X.shape[-1]
    if K == 1:
-         return HadamardTransform.apply(X.contiguous()) / torch.tensor(n).sqrt()
+         return torch.ops.torchao.hadamard_transform.default(X.contiguous()) / torch.tensor(n).sqrt()
    input = X.view(-1, K, n // K)
-     input = HadamardTransform.apply(input.contiguous()) / torch.tensor(n).sqrt()
+     input = torch.ops.torchao.hadamard_transform.default(input.contiguous()) / torch.tensor(n).sqrt()
    input = hadK.to(input.device).to(input.dtype) @ input
    return input.reshape(X.shape)
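Not part of the commit, just an illustrative sketch of what the new registration enables: the snippet below calls the freshly registered torchao::hadamard_transform op and compares it against the explicit scipy.linalg.hadamard matrix that the docstring references. It assumes a CUDA device (the pre-2.4 fallback only registers a "cuda" impl), an installed fast-hadamard-transform package, and an arbitrary power-of-2 dim of 256.

# Illustrative sketch only, not part of this diff.
import math

import scipy.linalg
import torch

dim = 256  # arbitrary power of 2, so no implicit zero-padding is involved
x = torch.randn(4, dim, device="cuda", dtype=torch.float16)

# Dispatch through the registered custom op; scale folds in the 1/sqrt(dim)
# normalization that matmul_hadU_fast applies separately.
out = torch.ops.torchao.hadamard_transform.default(x.contiguous(), 1.0 / math.sqrt(dim))

# Reference from the docstring: F.linear(x, scipy.linalg.hadamard(dim)) * scale.
H = torch.tensor(scipy.linalg.hadamard(dim), device=x.device, dtype=torch.float32)
ref = torch.nn.functional.linear(x.float(), H) / math.sqrt(dim)

# Loose tolerances to absorb fp16 accumulation error over 256 elements.
torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2)

The register_fake / impl_abstract half of the registration is what allows the op to run on meta tensors, so graphs that use it can be traced (e.g. under torch.compile) without executing the CUDA kernel.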