from typing import Dict, Tuple

import torch
- from torch import Tensor
import torch.nn.functional as F
+ from torch import Tensor


- # pyre-fixme[5]: Global expression must be annotated.
aten = torch.ops.aten
- # pyre-fixme[5]: Global expression must be annotated.
+
c10d_functional = torch.ops.c10d_functional

from typing import Any
- # pyre-fixme[5]: Global annotation cannot contain `Any`.
+
NF4_OPS_TABLE: Dict[Any, Any] = {}


- # pyre-fixme[3]: Return type must be annotated.
def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
    both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
    return (
-         both_nf4 and
-         a.block_size == b.block_size
+         both_nf4
+         and a.block_size == b.block_size
        and a.scaler_block_size == b.scaler_block_size
        and a.n_blocks == b.n_blocks
    )

- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
+
def implements(aten_ops):
    """Use this decorator to implement a function for an aten op in __torch_dispatch__"""

-     # pyre-fixme[53]: Captured variable `aten_ops` is not annotated.
-     # pyre-fixme[3]: Return type must be annotated.
-     # pyre-fixme[2]: Parameter must be annotated.
    def decorator(func):
        for op in aten_ops:
            NF4_OPS_TABLE[op] = func
        return func

    return decorator

+
@implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
def noop_detach(func, *args, **kwargs):
    return args[0][0]

+
@implements([torch.ops.aten._to_copy.default])
- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
def _to_copy(func, *args, **kwargs):
    if not args[0][0].is_contiguous():
        assert args[0][0].t().is_contiguous()
        return func(args[0][0].t()).t()
-     return args[0][0].get_original_weight().to(args[1]['dtype'])
+     return args[0][0].get_original_weight().to(args[1]["dtype"])
+

@implements([torch.ops.aten.to.dtype])
- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
def to_dtype(func, *args, **kwargs):
    if not args[0][0].is_contiguous():
        assert args[0][0].t().is_contiguous()
        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
    return args[0][0].get_original_weight().to(args[0][1])

+
@implements([torch.ops.aten.t.default])
- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
def t_default(func, *args, **kwargs):
    a = args[0][0]
    tensor_meta = SubclassTensorArgs(
-         a.size(),
-         (a.stride(1), a.stride(0)),
-         a.storage_offset(),
-         torch.bits2x4,
-         a.device,
-         a.requires_grad)
+         a.size(),
+         (a.stride(1), a.stride(0)),
+         a.storage_offset(),
+         torch.bits2x4,
+         a.device,
+         a.requires_grad,
+     )
    b = NF4Tensor(
-         tensor_meta,
-         a.block_size,
-         a.n_blocks,
-         a.scaler_block_size,
-         a.quantized_scalers,
-         a.quantization_factor,
-         a.scaler_mean,
-         a.quantized_data,
-         a.nf4)
+         tensor_meta,
+         a.block_size,
+         a.n_blocks,
+         a.scaler_block_size,
+         a.quantized_scalers,
+         a.quantization_factor,
+         a.scaler_mean,
+         a.quantized_data,
+         a.nf4,
+     )
    return b

+
@implements([torch.ops.aten.mm.default])
- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
def mm_default(func, *args, **kwargs):
    return linear_nf4(args[0][0], args[0][1])

@@ -101,14 +92,12 @@ def mm_default(func, *args, **kwargs):
        aten.copy_.default,
    ]
)
- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
def copy_(func, *args, **kwargs):
    original: NF4Tensor = args[0][0]
    copy_in: torch.Tensor = args[0][1]

    # Base Case
-     # pyre-fixme[6]: For 2nd argument expected `NF4Tensor` but got `Tensor`.
+
    if same_metadata(original, copy_in):
        original_tensors = original.__tensor_flatten__()[0]
        for tensor_name in original_tensors:
@@ -117,7 +106,9 @@ def copy_(func, *args, **kwargs):

    # Convert Non NF4Tensor into NF4 for copy in
    if not isinstance(copy_in, NF4Tensor):
-         copy_in_nf4 = NF4Tensor.from_tensor(copy_in, original.block_size, original.scaler_block_size)
+         copy_in_nf4 = NF4Tensor.from_tensor(
+             copy_in, original.block_size, original.scaler_block_size
+         )
        return original.copy_(copy_in_nf4)

    # Other Tensor is not a NF4Tensor
@@ -127,10 +118,11 @@ def copy_(func, *args, **kwargs):
    )
    return original.copy_(same_meta_nf4)

+
@dataclass
class SubclassTensorArgs:
    original_shape: torch.Size
-     # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+
    original_strides: Tuple
    storage_offset: int
    dtype: torch.dtype
@@ -161,7 +153,6 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
class NF4Tensor(torch.Tensor):
    """NF4Tensor class for converting a weight to the QLoRA NF4 format"""

-     # pyre-fixme[3]: Return type must be annotated.
    def __new__(
        cls,
        # Args related for base tensor construction
@@ -190,7 +181,6 @@ def __new__(

        """

-         # pyre-fixme[16]: `Tensor` has no attribute `_make_wrapper_subclass`.
        nf4tensor = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor_meta.original_shape,
@@ -203,7 +193,6 @@ def __new__(
        )
        return nf4tensor

-     # pyre-fixme[3]: Return type must be annotated.
    def __init__(
        self,
        tensor_meta: SubclassTensorArgs,
@@ -228,7 +217,6 @@ def __init__(

    @classmethod
    @torch.no_grad()
-     # pyre-fixme[3]: Return type must be annotated.
    def from_tensor(
        cls,
        inpt_tensor: torch.Tensor,
@@ -342,7 +330,6 @@ def double_quantize_scalers(
            n_scaler_blocks, scaler_block_size
        )

-         # pyre-fixme[58]: `/` is not supported for operand types `int` and `Tensor`.
        quantization_factor = 256 / (2 * scaler_absmax)
        # Length equal to weight numel // block_size
        quantized_scaler_blocks = scaler_blocks * quantization_factor
@@ -352,7 +339,7 @@ def double_quantize_scalers(
        # This is needed to make sure that quantization_factor remains a repeated view of n_scaler_blocks
        # For some reason the 127/scaler_absmax realizes n_scaler entries when only n_scaler_blocks are needed
        # The following will grab the first entry for the n_scaler_blocks which is the same across the scaler_block_size
-         # pyre-fixme[16]: `float` has no attribute `__getitem__`.
+
        quantization_factor = quantization_factor[:, 0]

        return (
@@ -389,7 +376,6 @@ def dequantize_scalers(

    @staticmethod
    def convert_to_norm_float_weight(
-         # pyre-fixme[11]: Annotation `tensor` is not defined as a type.
        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
    ) -> torch.Tensor:
        """Convert a tensor to the normalized float weight format"""
@@ -450,7 +436,6 @@ def get_original_weight(self) -> torch.Tensor:

    @staticmethod
    def quantize_tensor_nearest(
-         # pyre-fixme[11]: Annotation `float16` is not defined as a type.
        value: torch.float16, nf4: torch.Tensor
    ) -> torch.Tensor:
        """Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
@@ -461,9 +446,9 @@ def quantize_tensor_nearest(
        return closest_nf4

    @staticmethod
-     # pyre-fixme[14]: `dequantize` overrides method defined in `TensorBase`
+
    # inconsistently.
-     # pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
+
    # defined in `torch._C.TensorBase`.
    def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
        """Dequantize a nf4 value to bfloat16 format"""
@@ -475,7 +460,7 @@ def unpack(
    ) -> Tuple[
        int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
    ]:
-         # pyre-fixme[7]: Expected `Tuple[int, int, Tensor, Tensor, Tensor, Tensor,
+
        # Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
        return (
            self.block_size,
@@ -487,15 +472,12 @@ def unpack(
            self.quantized_data,
        )

-     # pyre-fixme[14]: `__repr__` overrides method defined in `Tensor` inconsistently.
-     # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"

    def __str__(self):
        return f"NF4Tensor({self.shape}, {self.block_size})"

-     # pyre-fixme[3]: Return type must be annotated.
    def __tensor_flatten__(self):
        tensor_meta = SubclassTensorArgs(
            self.shape,
@@ -520,10 +502,9 @@ def __tensor_flatten__(self):
        ], ctx

    @staticmethod
-     # pyre-fixme[3]: Return type must be annotated.
-     # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+
    # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
-     # pyre-fixme[2]: Parameter must be annotated.
+
    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
        assert len(inner_tensors) == 5, "Expected 5 inner tensors"
        return NF4Tensor(
@@ -538,28 +519,25 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
            inner_tensors["nf4"],
        )

-
-     # pyre-fixme[3]: Return type must be annotated.
    def __str__(self):
        return self.to(torch.float32).__str__()

    @classmethod
-     # pyre-fixme[3]: Return type must be annotated.
-     # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """TODO we are not supporting torch dispatch at the moment
        instead we have created a Autograd.Function to handle the linear
        """
        # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs
        # And don't support mixed tensor subclasses. This will trigger the handler for
        # the next type in the dispatch list
-         # pyre-fixme[3]: Return type must be annotated.
-         # pyre-fixme[2]: Parameter must be annotated.
+
        def allowed_subclasses(type):
            return (
                issubclass(cls, type)
                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
-                 or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
+                 or issubclass(
+                     torch._subclasses.functional_tensor.FunctionalTensor, type
+                 )
            )

        if not all(allowed_subclasses(t) for t in types):
@@ -572,25 +550,24 @@ def allowed_subclasses(type):
        )

    # Do not force the Float8Tensor type on the returned tensor
-     # pyre-fixme[4]: Attribute must be annotated.
+
    __torch_function__ = torch._C._disabled_torch_function_impl

+
class LinearNF4(torch.autograd.Function):
    @staticmethod
-     # pyre-fixme[14]: `forward` overrides method defined in `_SingleLevelFunction`
+
    # inconsistently.
-     # pyre-fixme[3]: Return type must be annotated.
-     # pyre-fixme[2]: Parameter must be annotated.
+
    def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
        """Save the quantized nf4 weight for backward pass"""
        ctx.nf4_weight = weight
        return F.linear(input, weight.to(input.dtype))

    @staticmethod
-     # pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
+
    # inconsistently.
-     # pyre-fixme[3]: Return type must be annotated.
-     # pyre-fixme[2]: Parameter must be annotated.
+
    def backward(ctx, grad_output):
        """The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()"""
        weight: NF4Tensor = ctx.nf4_weight
@@ -606,10 +583,7 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
    """
    return LinearNF4.apply(input, weight)

- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
- def to_nf4(tensor,
-            block_size: int = 64,
-            scaler_block_size: int = 256):
+
+ def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
    tensor1 = tensor.to(torch.bfloat16)
    return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)
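
For reference, a minimal usage sketch of the API this diff touches (to_nf4, NF4Tensor.get_original_weight, linear_nf4). It is not part of the change; the shapes are illustrative and assume, per the block layout above, that the weight's numel divides evenly by block_size and that the resulting number of per-block scalers divides evenly by scaler_block_size, with the definitions from this file in scope.

# Usage sketch (not part of this diff): quantize a bfloat16 weight to NF4,
# reconstruct it, and run the custom linear op.
import torch

# 512*512 elements -> 4096 blocks of 64, and 4096 scalers split into blocks of 256.
weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

# Dequantize back to bfloat16; NF4 is lossy, so compare approximately.
restored = nf4_weight.get_original_weight()
max_err = (weight - restored).abs().max()

# linear_nf4 routes through LinearNF4.forward, which dequantizes the weight
# to the input dtype before calling F.linear.
x = torch.randn(2, 512, dtype=torch.bfloat16)
out = linear_nf4(x, nf4_weight)  # shape (2, 512)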