Commit fbd88a7

generatedunixname89002005307016 authored and facebook-github-bot committed
suppress errors in pytorch
Differential Revision: D54963973
fbshipit-source-id: 0f3fb06bcc140c22ec99089d927a2641ced9aa37
1 parent 202d542 commit fbd88a7

File tree

12 files changed, +563 -4 lines changed


torchao/dtypes/nf4tensor.py

Lines changed: 56 additions & 0 deletions
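
The 56 lines added to this file are all Pyre suppression comments: a `# pyre-fixme[N]` comment silences Pyre error code N on the line immediately below it, without changing runtime behavior. A minimal illustration of the mechanism (not code from this commit):

# pyre-fixme[3]: Return type must be annotated.
def block_count(numel, block_size):
    return numel // block_size

# With annotations in place, the suppression above would no longer be needed:
def block_count_annotated(numel: int, block_size: int) -> int:
    return numel // block_size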
@@ -6,13 +6,17 @@
 import torch.nn.functional as F
 
 
+# pyre-fixme[5]: Global expression must be annotated.
 aten = torch.ops.aten
+# pyre-fixme[5]: Global expression must be annotated.
 c10d_functional = torch.ops.c10d_functional
 
 from typing import Any
+# pyre-fixme[5]: Global annotation cannot contain `Any`.
 NF4_OPS_TABLE: Dict[Any, Any] = {}
 
 
+# pyre-fixme[3]: Return type must be annotated.
 def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
     both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
     return (
@@ -22,9 +26,14 @@ def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
         and a.n_blocks == b.n_blocks
     )
 
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
 def implements(aten_ops):
     """Use this decorator to implement a function for an aten op in __torch_dispatch__"""
 
+    # pyre-fixme[53]: Captured variable `aten_ops` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
     def decorator(func):
         for op in aten_ops:
             NF4_OPS_TABLE[op] = func
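
A fully annotated version of `implements` would make the five suppressions in this hunk unnecessary. A minimal sketch, assuming handlers take arbitrary arguments (illustrative only, not the committed code):

from typing import Any, Callable, Dict, List

NF4_OPS_TABLE: Dict[Any, Callable[..., Any]] = {}

def implements(aten_ops: List[Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Register `func` as the handler for each given aten op in __torch_dispatch__."""
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        for op in aten_ops:
            NF4_OPS_TABLE[op] = func
        return func
    return decorator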
@@ -33,14 +42,20 @@ def decorator(func):
     return decorator
 
 @implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
 def noop_detach(func, *args, **kwargs):
     return args[0][0]
 
 @implements([torch.ops.aten._to_copy.default])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
 def _to_copy(func, *args, **kwargs):
     return args[0][0].get_original_weight().to(args[1]['dtype'])
 
 @implements([torch.ops.aten.to.dtype])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
 def to_dtype(func, *args, **kwargs):
     return args[0][0].get_original_weight().to(args[0][1])
 
@@ -50,11 +65,14 @@ def to_dtype(func, *args, **kwargs):
         aten.copy_.default,
     ]
 )
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
 def copy_(func, *args, **kwargs):
     original: NF4Tensor = args[0][0]
     copy_in: torch.Tensor = args[0][1]
 
     # Base Case
+    # pyre-fixme[6]: For 2nd argument expected `NF4Tensor` but got `Tensor`.
     if same_metadata(original, copy_in):
         original_tensors = original.__tensor_flatten__()[0]
         for tensor_name in original_tensors:
@@ -76,6 +94,7 @@ def copy_(func, *args, **kwargs):
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
     original_strides: Tuple
     storage_offset: int
     dtype: torch.dtype
@@ -106,6 +125,7 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
 class NF4Tensor(torch.Tensor):
     """NF4Tensor class for converting a weight to the QLoRA NF4 format"""
 
+    # pyre-fixme[3]: Return type must be annotated.
     def __new__(
         cls,
         # Args related for base tensor construction
@@ -134,6 +154,7 @@ def __new__(
 
         """
 
+        # pyre-fixme[16]: `Tensor` has no attribute `_make_wrapper_subclass`.
         nf4tensor = torch.Tensor._make_wrapper_subclass(
             cls,
             tensor_meta.original_shape,
@@ -145,6 +166,7 @@ def __new__(
         )
         return nf4tensor
 
+    # pyre-fixme[3]: Return type must be annotated.
     def __init__(
         self,
         tensor_meta: SubclassTensorArgs,
@@ -169,6 +191,7 @@ def __init__(
 
     @classmethod
     @torch.no_grad()
+    # pyre-fixme[3]: Return type must be annotated.
     def from_tensor(
         cls,
         inpt_tensor: torch.Tensor,
@@ -281,6 +304,7 @@ def double_quantize_scalers(
             n_scaler_blocks, scaler_block_size
         )
 
+        # pyre-fixme[58]: `/` is not supported for operand types `int` and `Tensor`.
         quantization_factor = 256 / (2 * scaler_absmax)
         # Length equal to weight numel // block_size
         quantized_scaler_blocks = scaler_blocks * quantization_factor
@@ -290,6 +314,7 @@ def double_quantize_scalers(
         # This is needed to make sure that quantization_factor remains a repeated view of n_scaler_blocks
         # For some reason the 127/scaler_absmax realizes n_scaler entries when only n_scaler_blocks are needed
         # The following will grab the first entry for the n_scaler_blocks which is the same across the scaler_block_size
+        # pyre-fixme[16]: `float` has no attribute `__getitem__`.
         quantization_factor = quantization_factor[:, 0]
 
         return (
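
As a rough numeric illustration of the scaling applied in these two hunks, 256 / (2 * scaler_absmax) maps each scaler block into signed int8 range before rounding. The numbers below are invented for this note, not taken from the file:

import torch

# Hypothetical per-block absolute maxima for three scaler blocks.
scaler_absmax = torch.tensor([[0.5], [2.0], [4.0]])

# Same formula as in double_quantize_scalers.
quantization_factor = 256 / (2 * scaler_absmax)   # tensor([[256.], [64.], [32.]])

scaler_blocks = torch.tensor([[0.25], [-1.0], [3.5]])
quantized = (scaler_blocks * quantization_factor).round()  # all values now lie in [-128, 127]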
@@ -326,6 +351,7 @@ def dequantize_scalers(
 
     @staticmethod
     def convert_to_norm_float_weight(
+        # pyre-fixme[11]: Annotation `tensor` is not defined as a type.
         inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
     ) -> torch.Tensor:
         """Convert a tensor to the normalized float weight format"""
@@ -386,6 +412,7 @@ def get_original_weight(self) -> torch.Tensor:
 
     @staticmethod
     def quantize_tensor_nearest(
+        # pyre-fixme[11]: Annotation `float16` is not defined as a type.
         value: torch.float16, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
@@ -396,6 +423,10 @@ def quantize_tensor_nearest(
         return closest_nf4
 
     @staticmethod
+    # pyre-fixme[14]: `dequantize` overrides method defined in `TensorBase`
+    #  inconsistently.
+    # pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
+    #  defined in `torch._C.TensorBase`.
     def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
         """Dequantize a nf4 value to float16 format"""
         # return nf4.index_select(0, value)
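
The inverse operation is a table lookup, consistent with the commented-out index_select call above; again an illustrative sketch rather than the committed implementation:

import torch

def dequantize_nearest(indices: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
    # Each stored index selects its NF4 code point from the lookup table.
    return nf4.index_select(0, indices.to(torch.long))

codes = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
values = dequantize_nearest(torch.tensor([2, 1], dtype=torch.uint8), codes)  # tensor([0.0, -0.5])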
@@ -406,6 +437,8 @@ def unpack(
     ) -> Tuple[
         int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
     ]:
+        # pyre-fixme[7]: Expected `Tuple[int, int, Tensor, Tensor, Tensor, Tensor,
+        #  Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
         return (
             self.block_size,
             self.n_blocks,
@@ -416,12 +449,15 @@ def unpack(
             self.quantized_data,
         )
 
+    # pyre-fixme[14]: `__repr__` overrides method defined in `Tensor` inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
     def __repr__(self):
         return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"
 
     def __str__(self):
         return f"NF4Tensor({self.shape}, {self.block_size})"
 
+    # pyre-fixme[3]: Return type must be annotated.
     def __tensor_flatten__(self):
         tensor_meta = SubclassTensorArgs(
             self.shape,
@@ -446,6 +482,10 @@ def __tensor_flatten__(self):
         ], ctx
 
     @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+    #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
+    # pyre-fixme[2]: Parameter must be annotated.
     def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 5, "Expected 5 inner tensors"
         return NF4Tensor(
@@ -461,17 +501,22 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
         )
 
 
+    # pyre-fixme[3]: Return type must be annotated.
     def __str__(self):
         return self.to(torch.float32).__str__()
 
     @classmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
         instead we have created a Autograd.Function to handle the linear
         """
         # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs
         # And don't support mixed tensor subclasses. This will trigger the handler for
         # the next type in the dispatch list
+        # pyre-fixme[3]: Return type must be annotated.
+        # pyre-fixme[2]: Parameter must be annotated.
         def allowed_subclasses(type):
             return (
                 issubclass(cls, type)
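
For orientation, the dispatch path these suppressions sit inside boils down to a table lookup: the incoming aten op is looked up in NF4_OPS_TABLE and the registered handler is called. A simplified sketch, with the helper name invented here and the call convention assumed from the handlers earlier in the diff (which unpack args[0]):

def dispatch_nf4_op(func, args, kwargs=None):
    # Simplified routing: look the aten op up and call its registered handler.
    if func in NF4_OPS_TABLE:
        return NF4_OPS_TABLE[func](func, args, kwargs)
    raise NotImplementedError(f"{func} is not supported for NF4Tensor")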
@@ -489,16 +534,25 @@ def allowed_subclasses(type):
             )
 
     # Do not force the Float8Tensor type on the returned tensor
+    # pyre-fixme[4]: Attribute must be annotated.
     __torch_function__ = torch._C._disabled_torch_function_impl
 
 class LinearNF4(torch.autograd.Function):
     @staticmethod
+    # pyre-fixme[14]: `forward` overrides method defined in `_SingleLevelFunction`
+    #  inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
         """Save the quantized nf4 weight for backward pass"""
         ctx.nf4_weight = weight
         return F.linear(input, weight.get_original_weight())
 
     @staticmethod
+    # pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
+    #  inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
     def backward(ctx, grad_output):
         """The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()"""
         weight: NF4Tensor = ctx.nf4_weight
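
The backward rule stated in the docstring can be checked numerically: for y = F.linear(x, W), the gradient with respect to x is grad_output @ W, and the NF4 weight receives no gradient. A small self-contained check, with shapes invented for this note:

import torch

x = torch.randn(2, 4, requires_grad=True)
W = torch.randn(3, 4)
y = torch.nn.functional.linear(x, W)        # y = x @ W.T, shape (2, 3)
grad_output = torch.ones_like(y)
y.backward(grad_output)
assert torch.allclose(x.grad, grad_output @ W)  # matches the docstring's rule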
@@ -514,6 +568,8 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
     """
     return LinearNF4.apply(input, weight)
 
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
 def to_nf4(tensor,
            block_size: int = 64,
            scaler_block_size: int = 256):
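
Taken together, the entry points touched in this file are used roughly as below. This is a usage sketch based on the signatures visible in the diff; the bfloat16 dtype and shapes are assumptions for illustration:

import torch
from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)  # NF4Tensor

x = torch.randn(8, 512, dtype=torch.bfloat16, requires_grad=True)
out = linear_nf4(x, nf4_weight)  # gradients flow to x only; the NF4 weight stays frozen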

0 commit comments
