Skip to content

Commit 9b25ecc

Browse files
authored
Remove args and kwargs from AffineQuantizedTensor (#247)
Summary: att Test Plan: python test/quantization/test_quant_api.py Reviewers: Subscribers: Tasks: Tags:
1 parent 9dbdb2b commit 9b25ecc

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torchao/quantization/subclass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -670,18 +670,18 @@ def __new__(
670670
quant_max: Optional[int] = None,
671671
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
672672
dtype=None,
673-
# TODO: remove args and kwargs
674-
*args,
675-
**kwargs
673+
strides=None,
676674
):
675+
kwargs = {}
677676
kwargs["device"] = int_data.device
678677
kwargs["layout"] = (
679678
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
680679
)
681680
if dtype is None:
682681
dtype = scale.dtype
683682
kwargs["dtype"] = dtype
684-
assert not kwargs.get("requires_grad", False)
683+
if strides is not None:
684+
kwargs["strides"] = strides
685685
kwargs["requires_grad"] = False
686686
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
687687

@@ -696,8 +696,7 @@ def __init__(
696696
quant_max: Optional[int] = None,
697697
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
698698
dtype=None,
699-
*args,
700-
**kwargs
699+
strides=None,
701700
):
702701
self.int_data = int_data
703702
self.scale = scale
@@ -912,6 +911,7 @@ def _apply_fn_to_data(self, fn):
912911
self.quant_max,
913912
self.zero_point_domain,
914913
dtype=self.dtype,
914+
strides=self.stride(),
915915
)
916916

917917
@classmethod

0 commit comments

Comments
 (0)