Skip to content

Commit 59c7311

Browse files
authored
[AFQ] Optimize tensor_flatten for runtime (#1951)
[ghstack-poisoned]
1 parent dfbd681 commit 59c7311

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,18 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
165165
return dq
166166

167167
def __tensor_flatten__(self):
168-
return ["tensor_impl"], [
169-
self.block_size,
170-
self.shape,
171-
self.quant_min,
172-
self.quant_max,
173-
self.zero_point_domain,
174-
self.dtype,
175-
]
168+
# This is used at runtime to unwrap AffineQuantizedTensor activations.
169+
# AffineQuantizedTensor has __torch_function__ override:
170+
# Each getattr will go through it, which is up to 10x slower than default attribute access.
171+
with torch._C.DisableTorchFunctionSubclass():
172+
return ["tensor_impl"], [
173+
self.block_size,
174+
self.shape,
175+
self.quant_min,
176+
self.quant_max,
177+
self.zero_point_domain,
178+
self.dtype,
179+
]
176180

177181
@classmethod
178182
def __tensor_unflatten__(

0 commit comments

Comments
 (0)