
Commit 185236c

Much lint, so wow (#76)
1 parent f7e12c8 commit 185236c

File tree

13 files changed (+303, -632 lines)


CODEOWNERS

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+msaroufim
+cpuhrsch

torchao/dtypes/nf4tensor.py

Lines changed: 53 additions & 79 deletions
@@ -2,96 +2,87 @@
 from typing import Dict, Tuple

 import torch
-from torch import Tensor
 import torch.nn.functional as F
+from torch import Tensor


-# pyre-fixme[5]: Global expression must be annotated.
 aten = torch.ops.aten
-# pyre-fixme[5]: Global expression must be annotated.
+
 c10d_functional = torch.ops.c10d_functional

 from typing import Any
-# pyre-fixme[5]: Global annotation cannot contain `Any`.
+
 NF4_OPS_TABLE: Dict[Any, Any] = {}


-# pyre-fixme[3]: Return type must be annotated.
 def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
     both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
     return (
-        both_nf4 and
-        a.block_size == b.block_size
+        both_nf4
+        and a.block_size == b.block_size
         and a.scaler_block_size == b.scaler_block_size
         and a.n_blocks == b.n_blocks
     )

-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
+
 def implements(aten_ops):
     """Use this decorator to implement a function for an aten op in __torch_dispatch__"""

-    # pyre-fixme[53]: Captured variable `aten_ops` is not annotated.
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
     def decorator(func):
         for op in aten_ops:
             NF4_OPS_TABLE[op] = func
         return func

     return decorator

+
 @implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
 def noop_detach(func, *args, **kwargs):
     return args[0][0]

+
 @implements([torch.ops.aten._to_copy.default])
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
 def _to_copy(func, *args, **kwargs):
     if not args[0][0].is_contiguous():
         assert args[0][0].t().is_contiguous()
         return func(args[0][0].t()).t()
-    return args[0][0].get_original_weight().to(args[1]['dtype'])
+    return args[0][0].get_original_weight().to(args[1]["dtype"])
+

 @implements([torch.ops.aten.to.dtype])
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
 def to_dtype(func, *args, **kwargs):
     if not args[0][0].is_contiguous():
         assert args[0][0].t().is_contiguous()
         return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
     return args[0][0].get_original_weight().to(args[0][1])

+
 @implements([torch.ops.aten.t.default])
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
 def t_default(func, *args, **kwargs):
     a = args[0][0]
     tensor_meta = SubclassTensorArgs(
-        a.size(),
-        (a.stride(1), a.stride(0)),
-        a.storage_offset(),
-        torch.bits2x4,
-        a.device,
-        a.requires_grad)
+        a.size(),
+        (a.stride(1), a.stride(0)),
+        a.storage_offset(),
+        torch.bits2x4,
+        a.device,
+        a.requires_grad,
+    )
     b = NF4Tensor(
-        tensor_meta,
-        a.block_size,
-        a.n_blocks,
-        a.scaler_block_size,
-        a.quantized_scalers,
-        a.quantization_factor,
-        a.scaler_mean,
-        a.quantized_data,
-        a.nf4)
+        tensor_meta,
+        a.block_size,
+        a.n_blocks,
+        a.scaler_block_size,
+        a.quantized_scalers,
+        a.quantization_factor,
+        a.scaler_mean,
+        a.quantized_data,
+        a.nf4,
+    )
     return b

+
 @implements([torch.ops.aten.mm.default])
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
 def mm_default(func, *args, **kwargs):
     return linear_nf4(args[0][0], args[0][1])

@@ -101,14 +92,12 @@ def mm_default(func, *args, **kwargs):
         aten.copy_.default,
     ]
 )
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
 def copy_(func, *args, **kwargs):
     original: NF4Tensor = args[0][0]
     copy_in: torch.Tensor = args[0][1]

     # Base Case
-    # pyre-fixme[6]: For 2nd argument expected `NF4Tensor` but got `Tensor`.
+
     if same_metadata(original, copy_in):
         original_tensors = original.__tensor_flatten__()[0]
         for tensor_name in original_tensors:
@@ -117,7 +106,9 @@ def copy_(func, *args, **kwargs):

     # Convert Non NF4Tensor into NF4 for copy in
     if not isinstance(copy_in, NF4Tensor):
-        copy_in_nf4 = NF4Tensor.from_tensor(copy_in, original.block_size, original.scaler_block_size)
+        copy_in_nf4 = NF4Tensor.from_tensor(
+            copy_in, original.block_size, original.scaler_block_size
+        )
         return original.copy_(copy_in_nf4)

     # Other Tensor is not a NF4Tensor
@@ -127,10 +118,11 @@ def copy_(func, *args, **kwargs):
     )
     return original.copy_(same_meta_nf4)

+
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
-    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+
     original_strides: Tuple
     storage_offset: int
     dtype: torch.dtype
@@ -161,7 +153,6 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
 class NF4Tensor(torch.Tensor):
     """NF4Tensor class for converting a weight to the QLoRA NF4 format"""

-    # pyre-fixme[3]: Return type must be annotated.
     def __new__(
         cls,
         # Args related for base tensor construction
@@ -190,7 +181,6 @@ def __new__(

         """

-        # pyre-fixme[16]: `Tensor` has no attribute `_make_wrapper_subclass`.
         nf4tensor = torch.Tensor._make_wrapper_subclass(
             cls,
             tensor_meta.original_shape,
@@ -203,7 +193,6 @@ def __new__(
         )
         return nf4tensor

-    # pyre-fixme[3]: Return type must be annotated.
     def __init__(
         self,
         tensor_meta: SubclassTensorArgs,
@@ -228,7 +217,6 @@ def __init__(

     @classmethod
     @torch.no_grad()
-    # pyre-fixme[3]: Return type must be annotated.
     def from_tensor(
         cls,
         inpt_tensor: torch.Tensor,
@@ -342,7 +330,6 @@ def double_quantize_scalers(
             n_scaler_blocks, scaler_block_size
         )

-        # pyre-fixme[58]: `/` is not supported for operand types `int` and `Tensor`.
         quantization_factor = 256 / (2 * scaler_absmax)
         # Length equal to weight numel // block_size
         quantized_scaler_blocks = scaler_blocks * quantization_factor
@@ -352,7 +339,7 @@ def double_quantize_scalers(
         # This is needed to make sure that quantization_factor remains a repeated view of n_scaler_blocks
         # For some reason the 127/scaler_absmax realizes n_scaler entries when only n_scaler_blocks are needed
         # The following will grab the first entry for the n_scaler_blocks which is the same across the scaler_block_size
-        # pyre-fixme[16]: `float` has no attribute `__getitem__`.
+
         quantization_factor = quantization_factor[:, 0]

         return (
@@ -389,7 +376,6 @@ def dequantize_scalers(

     @staticmethod
     def convert_to_norm_float_weight(
-        # pyre-fixme[11]: Annotation `tensor` is not defined as a type.
         inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
     ) -> torch.Tensor:
         """Convert a tensor to the normalized float weight format"""
@@ -450,7 +436,6 @@ def get_original_weight(self) -> torch.Tensor:

     @staticmethod
     def quantize_tensor_nearest(
-        # pyre-fixme[11]: Annotation `float16` is not defined as a type.
         value: torch.float16, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
@@ -461,9 +446,9 @@ def quantize_tensor_nearest(
         return closest_nf4

     @staticmethod
-    # pyre-fixme[14]: `dequantize` overrides method defined in `TensorBase`
+
     # inconsistently.
-    # pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
+
     # defined in `torch._C.TensorBase`.
     def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
         """Dequantize a nf4 value to bfloat16 format"""
@@ -475,7 +460,7 @@ def unpack(
     ) -> Tuple[
         int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
     ]:
-        # pyre-fixme[7]: Expected `Tuple[int, int, Tensor, Tensor, Tensor, Tensor,
+
         # Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
         return (
             self.block_size,
@@ -487,15 +472,12 @@ def unpack(
             self.quantized_data,
         )

-    # pyre-fixme[14]: `__repr__` overrides method defined in `Tensor` inconsistently.
-    # pyre-fixme[3]: Return type must be annotated.
     def __repr__(self):
         return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"

     def __str__(self):
         return f"NF4Tensor({self.shape}, {self.block_size})"

-    # pyre-fixme[3]: Return type must be annotated.
     def __tensor_flatten__(self):
         tensor_meta = SubclassTensorArgs(
             self.shape,
@@ -520,10 +502,9 @@ def __tensor_flatten__(self):
         ], ctx

     @staticmethod
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+
     # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
-    # pyre-fixme[2]: Parameter must be annotated.
+
     def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 5, "Expected 5 inner tensors"
         return NF4Tensor(
@@ -538,28 +519,25 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
             inner_tensors["nf4"],
         )

-
-    # pyre-fixme[3]: Return type must be annotated.
     def __str__(self):
         return self.to(torch.float32).__str__()

     @classmethod
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
         instead we have created a Autograd.Function to handle the linear
         """
         # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs
         # And don't support mixed tensor subclasses. This will trigger the handler for
         # the next type in the dispatch list
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
+
         def allowed_subclasses(type):
             return (
                 issubclass(cls, type)
                 or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
-                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
+                or issubclass(
+                    torch._subclasses.functional_tensor.FunctionalTensor, type
+                )
             )

         if not all(allowed_subclasses(t) for t in types):
@@ -572,25 +550,24 @@ def allowed_subclasses(type):
             )

     # Do not force the Float8Tensor type on the returned tensor
-    # pyre-fixme[4]: Attribute must be annotated.
+
     __torch_function__ = torch._C._disabled_torch_function_impl

+
 class LinearNF4(torch.autograd.Function):
     @staticmethod
-    # pyre-fixme[14]: `forward` overrides method defined in `_SingleLevelFunction`
+
     # inconsistently.
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
+
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
         """Save the quantized nf4 weight for backward pass"""
         ctx.nf4_weight = weight
         return F.linear(input, weight.to(input.dtype))

     @staticmethod
-    # pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
+
     # inconsistently.
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
+
     def backward(ctx, grad_output):
         """The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()"""
         weight: NF4Tensor = ctx.nf4_weight
@@ -606,10 +583,7 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
     """
     return LinearNF4.apply(input, weight)

-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
-def to_nf4(tensor,
-           block_size: int = 64,
-           scaler_block_size: int = 256):
+
+def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
     tensor1 = tensor.to(torch.bfloat16)
     return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)
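
For orientation, here is a minimal usage sketch of the two entry points this diff touches, to_nf4 and linear_nf4. It is not part of the commit: the import path is inferred from the file location in the diff, and the shapes, dtype, and block sizes (the defaults shown above) are illustrative assumptions.

import torch

# Import path assumed from torchao/dtypes/nf4tensor.py as shown in this diff.
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

# Quantize a weight to NF4. The blockwise scheme expects block_size to divide
# the weight's numel, and scaler_block_size to divide numel // block_size.
weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

# linear_nf4 routes through the LinearNF4 autograd function above: forward
# dequantizes the weight to the activation dtype and calls F.linear.
x = torch.randn(4, 512, dtype=torch.bfloat16)
out = linear_nf4(x, nf4_weight)  # shape: (4, 512)

Per the docstrings in the diff, the forward pass stashes the NF4 weight on ctx and the backward returns grad_output @ weight.get_original_weight(), so the quantized weight itself never receives gradients.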
