@@ -803,7 +803,6 @@ def dequantize_affine(context, node):
         int_data.astype(quantized_np_dtype),
         zero_point,
         scale,
-        axis=-1,
         name=node.name,
     )
     context.add(output, node.name)
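
Note on the hunk above: dropping the hardcoded axis=-1 presumably lets the
shared constexpr-dequant constructor infer the quantization axis from the
shapes of scale and zero_point, which the pre-iOS18 constexpr_affine_dequantize
op needs to get right. Below is a minimal numpy sketch of the per-channel
affine dequantization semantics involved (all names are illustrative, not from
the PR): for an (n, k) weight with scale of shape (n, 1), the quantization
axis is 0, so a hardcoded axis=-1 would mislabel this layout.

    import numpy as np

    n, k = 4, 128
    rng = np.random.default_rng(0)
    int_data = rng.integers(-128, 128, size=(n, k), dtype=np.int8)
    scale = rng.random((n, 1), dtype=np.float32)
    zero_point = rng.integers(-128, 128, size=(n, 1), dtype=np.int8)

    # Per-channel affine dequantization: subtract the per-row zero point,
    # then multiply by the per-row scale; broadcasting runs along axis 0.
    dequantized = (int_data.astype(np.float32) - zero_point.astype(np.float32)) * scale
    assert dequantized.shape == (n, k)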
@@ -272,6 +272,52 @@ def forward(self, x):
         prog = res[1]._mil_program
         assert get_op_types_in_program(prog) == ["constexpr_blockwise_shift_scale", "linear"]
 
+    @pytest.mark.skipif(not _HAS_TORCHAO, reason=MSG_TORCHAO_NOT_FOUND)
+    @pytest.mark.parametrize(
+        "compute_unit, has_zeros, minimum_deployment_target",
+        itertools.product(compute_units, [True, False], [ct.target.iOS16, ct.target.iOS17]),
+    )
+    def test_dequantize_affine_before_ios18(self, compute_unit, has_zeros, minimum_deployment_target):
+
+        quant_min = -128
+        quant_max = 127
+
+        n = 4
+        k = 128
+        input_dtype = torch.int8
+        int_data = torch.randint(low=quant_min, high=quant_max, size=(n, k)).to(input_dtype)
+        scale = torch.rand(n, 1)
+
+        zero_point = None
+        if has_zeros:
+            zero_point = torch.randint(low=quant_min, high=quant_max, size=(n, 1)).to(input_dtype)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("int_data", int_data)
+                self.register_buffer("scale", scale)
+                self.register_buffer("zero_point", zero_point)
+
+            def forward(self, x):
+                w = torchao_quant.dequantize_affine(self.int_data, [1, k], self.scale, self.zero_point, input_dtype, quant_min, quant_max)
+                return torch.nn.functional.linear(x, w)
+
+
+        model = Model()
+        model = model.to(torch.device("cpu"))
+
+        input_shape = [(3, k)]
+        res = self.run_compare_torch(
+            input_shape,
+            model,
+            minimum_deployment_target=minimum_deployment_target,
+            compute_unit=compute_unit,
+            rtol=0.1,
+            frontend=TorchFrontend.TORCHEXPORT,
+        )
+        prog = res[1]._mil_program
+        assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "linear"]


 # TODO(rdar://108463675): refactor torch op tests later to parametrize quantized vs standard ops
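
Usage note: a hypothetical end-to-end sketch of the path the new test
exercises, assuming torchao's public quant_primitives API and coremltools'
torch.export conversion flow; ToyModel, the shapes, and the deployment target
are illustrative, not taken from the PR.

    import torch
    import coremltools as ct
    from torchao.quantization import quant_primitives as torchao_quant

    n, k = 4, 128
    int_data = torch.randint(-128, 127, (n, k), dtype=torch.int8)
    scale = torch.rand(n, 1)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("int_data", int_data)
            self.register_buffer("scale", scale)

        def forward(self, x):
            # Dequantize the stored int8 weight, then apply it as a linear layer.
            w = torchao_quant.dequantize_affine(
                self.int_data, [1, k], self.scale, None, torch.int8, -128, 127
            )
            return torch.nn.functional.linear(x, w)

    exported = torch.export.export(ToyModel().eval(), (torch.rand(3, k),))

    # Per the test assertions, an iOS16/iOS17 target should now lower
    # dequantize_affine to constexpr_affine_dequantize (iOS18 targets keep
    # using constexpr_blockwise_shift_scale).
    mlmodel = ct.convert(exported, minimum_deployment_target=ct.target.iOS17)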