import torch.nn.functional as F


+# pyre-fixme[5]: Global expression must be annotated.
aten = torch.ops.aten
+# pyre-fixme[5]: Global expression must be annotated.
c10d_functional = torch.ops.c10d_functional

from typing import Any
+# pyre-fixme[5]: Global annotation cannot contain `Any`.
NF4_OPS_TABLE: Dict[Any, Any] = {}


+# pyre-fixme[3]: Return type must be annotated.
def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
    both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
    return (
@@ -22,9 +26,14 @@ def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
        and a.n_blocks == b.n_blocks
    )

+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
def implements(aten_ops):
    """Use this decorator to implement a function for an aten op in __torch_dispatch__"""

+    # pyre-fixme[53]: Captured variable `aten_ops` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
    def decorator(func):
        for op in aten_ops:
            NF4_OPS_TABLE[op] = func
@@ -33,14 +42,20 @@ def decorator(func):
    return decorator

@implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
def noop_detach(func, *args, **kwargs):
    return args[0][0]

@implements([torch.ops.aten._to_copy.default])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
def _to_copy(func, *args, **kwargs):
    return args[0][0].get_original_weight().to(args[1]['dtype'])

@implements([torch.ops.aten.to.dtype])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
def to_dtype(func, *args, **kwargs):
    return args[0][0].get_original_weight().to(args[0][1])

@@ -50,11 +65,14 @@ def to_dtype(func, *args, **kwargs):
        aten.copy_.default,
    ]
)
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
def copy_(func, *args, **kwargs):
    original: NF4Tensor = args[0][0]
    copy_in: torch.Tensor = args[0][1]

    # Base Case
+    # pyre-fixme[6]: For 2nd argument expected `NF4Tensor` but got `Tensor`.
    if same_metadata(original, copy_in):
        original_tensors = original.__tensor_flatten__()[0]
        for tensor_name in original_tensors:
@@ -76,6 +94,7 @@ def copy_(func, *args, **kwargs):
@dataclass
class SubclassTensorArgs:
    original_shape: torch.Size
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
    original_strides: Tuple
    storage_offset: int
    dtype: torch.dtype
@@ -106,6 +125,7 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
class NF4Tensor(torch.Tensor):
    """NF4Tensor class for converting a weight to the QLoRA NF4 format"""

+    # pyre-fixme[3]: Return type must be annotated.
    def __new__(
        cls,
        # Args related for base tensor construction
@@ -134,6 +154,7 @@ def __new__(

        """

+        # pyre-fixme[16]: `Tensor` has no attribute `_make_wrapper_subclass`.
        nf4tensor = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor_meta.original_shape,
@@ -145,6 +166,7 @@ def __new__(
        )
        return nf4tensor

+    # pyre-fixme[3]: Return type must be annotated.
    def __init__(
        self,
        tensor_meta: SubclassTensorArgs,
@@ -169,6 +191,7 @@ def __init__(

    @classmethod
    @torch.no_grad()
+    # pyre-fixme[3]: Return type must be annotated.
    def from_tensor(
        cls,
        inpt_tensor: torch.Tensor,
@@ -281,6 +304,7 @@ def double_quantize_scalers(
            n_scaler_blocks, scaler_block_size
        )

+        # pyre-fixme[58]: `/` is not supported for operand types `int` and `Tensor`.
        quantization_factor = 256 / (2 * scaler_absmax)
        # Length equal to weight numel // block_size
        quantized_scaler_blocks = scaler_blocks * quantization_factor
@@ -290,6 +314,7 @@ def double_quantize_scalers(
        # This is needed to make sure that quantization_factor remains a repeated view of n_scaler_blocks
        # For some reason the 127/scaler_absmax realizes n_scaler entries when only n_scaler_blocks are needed
        # The following will grab the first entry for the n_scaler_blocks which is the same across the scaler_block_size
+        # pyre-fixme[16]: `float` has no attribute `__getitem__`.
        quantization_factor = quantization_factor[:, 0]

        return (
@@ -326,6 +351,7 @@ def dequantize_scalers(

    @staticmethod
    def convert_to_norm_float_weight(
+        # pyre-fixme[11]: Annotation `tensor` is not defined as a type.
        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
    ) -> torch.Tensor:
        """Convert a tensor to the normalized float weight format"""
@@ -386,6 +412,7 @@ def get_original_weight(self) -> torch.Tensor:

    @staticmethod
    def quantize_tensor_nearest(
+        # pyre-fixme[11]: Annotation `float16` is not defined as a type.
        value: torch.float16, nf4: torch.Tensor
    ) -> torch.Tensor:
        """Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
@@ -396,6 +423,10 @@ def quantize_tensor_nearest(
        return closest_nf4

    @staticmethod
+    # pyre-fixme[14]: `dequantize` overrides method defined in `TensorBase`
+    #  inconsistently.
+    # pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
+    #  defined in `torch._C.TensorBase`.
    def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
        """Dequantize a nf4 value to float16 format"""
        # return nf4.index_select(0, value)
@@ -406,6 +437,8 @@ def unpack(
    ) -> Tuple[
        int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
    ]:
+        # pyre-fixme[7]: Expected `Tuple[int, int, Tensor, Tensor, Tensor, Tensor,
+        #  Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
        return (
            self.block_size,
            self.n_blocks,
@@ -416,12 +449,15 @@ def unpack(
            self.quantized_data,
        )

+    # pyre-fixme[14]: `__repr__` overrides method defined in `Tensor` inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"

    def __str__(self):
        return f"NF4Tensor({self.shape}, {self.block_size})"

+    # pyre-fixme[3]: Return type must be annotated.
    def __tensor_flatten__(self):
        tensor_meta = SubclassTensorArgs(
            self.shape,
@@ -446,6 +482,10 @@ def __tensor_flatten__(self):
        ], ctx

    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+    #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
+    # pyre-fixme[2]: Parameter must be annotated.
    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
        assert len(inner_tensors) == 5, "Expected 5 inner tensors"
        return NF4Tensor(
@@ -461,17 +501,22 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
        )


+    # pyre-fixme[3]: Return type must be annotated.
    def __str__(self):
        return self.to(torch.float32).__str__()

    @classmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """TODO we are not supporting torch dispatch at the moment
        instead we have created a Autograd.Function to handle the linear
        """
        # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs
        # And don't support mixed tensor subclasses. This will trigger the handler for
        # the next type in the dispatch list
+        # pyre-fixme[3]: Return type must be annotated.
+        # pyre-fixme[2]: Parameter must be annotated.
        def allowed_subclasses(type):
            return (
                issubclass(cls, type)
@@ -489,16 +534,25 @@ def allowed_subclasses(type):
            )

    # Do not force the Float8Tensor type on the returned tensor
+    # pyre-fixme[4]: Attribute must be annotated.
    __torch_function__ = torch._C._disabled_torch_function_impl

class LinearNF4(torch.autograd.Function):
    @staticmethod
+    # pyre-fixme[14]: `forward` overrides method defined in `_SingleLevelFunction`
+    #  inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
    def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
        """Save the quantized nf4 weight for backward pass"""
        ctx.nf4_weight = weight
        return F.linear(input, weight.get_original_weight())

    @staticmethod
+    # pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
+    #  inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
    def backward(ctx, grad_output):
        """The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()"""
        weight: NF4Tensor = ctx.nf4_weight
@@ -514,6 +568,8 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
    """
    return LinearNF4.apply(input, weight)

+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
def to_nf4(tensor,
           block_size: int = 64,
           scaler_block_size: int = 256):
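
For orientation, here is a minimal usage sketch of the two public entry points this diff touches, to_nf4 and linear_nf4. The import path and the tensor shapes are assumptions for illustration only; adjust them to wherever this module lives in your checkout.

import torch
# Assumed module path for this file; not asserted by the diff itself.
from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

# Illustrative shapes: weight.numel() (512 * 512) divides evenly by block_size,
# and the resulting 4096 per-block scalers divide evenly by scaler_block_size.
weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

x = torch.randn(2, 512, dtype=torch.bfloat16)
out = linear_nf4(x, nf4_weight)  # forwards through LinearNF4.apply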