[XNNPACK] Serialize weights as fp16 rather than fp32 #9753

Merged · 1 commit · Mar 31, 2025
20 changes: 9 additions & 11 deletions backends/xnnpack/operators/node_visitor.py
@@ -210,7 +210,7 @@ def get_serialized_dtype(
self,
quant_params: Optional[QuantParams],
node: torch.fx.Node,
-fp32_static_weight: bool = False,
+force_fp32: bool = False,
) -> XNNDatatype:
# Default initialization
dtype = XNNDatatype.xnn_datatype_fp32
@@ -267,7 +267,7 @@ def get_per_channel_dtype(
if node_dtype is not None and node_dtype == torch.float16:
dtype = (
XNNDatatype.xnn_datatype_fp32
-if fp32_static_weight
+if force_fp32
else XNNDatatype.xnn_datatype_fp16
)

@@ -348,7 +348,7 @@ def define_tensor( # noqa: C901
convert_to_nhwc: bool = False,
swap_in_out_for_weights: bool = False,
quant_params: Optional[QuantParams] = None,
-fp32_static_weights: bool = False,
+force_fp32: bool = False,
groups: int = 1,
) -> None:
"""
@@ -368,7 +368,7 @@ def define_tensor( # noqa: C901
constant data. If used along with convert_to_nhwc, this
swap will happen before converting to nhwc.
quant_params: Quantization meta data for this tensor, None if it is not quantized
-fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
+force_fp32: forces tensor to be serialized as fp32, used for bias of dynamically quantized ops
Reviewer comment (Contributor): s/fp32_static_weight/force_fp32 - seems a little too vague if you ask me.

groups: number of groups for swap_in_out_for_weights
"""

@@ -405,7 +405,7 @@ def define_tensor( # noqa: C901
convert_to_nhwc,
swap_in_out_for_weights,
quant_params,
-fp32_static_weights,
+force_fp32,
groups,
)

@@ -417,9 +417,7 @@ def define_tensor( # noqa: C901
check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
dims = [dims[i] for i in PERM_NCHW_TO_NHWC]

-dtype = self.get_serialized_dtype(
-quant_params, tensor, fp32_static_weight=fp32_static_weights
-)
+dtype = self.get_serialized_dtype(quant_params, tensor, force_fp32=force_fp32)

tvalue = XNNTensorValue(
datatype=dtype,
@@ -504,7 +502,7 @@ def get_serialized_buffer_index(
convert_to_nhwc: bool,
swap_in_out_for_weights: bool,
quant_params: Optional[QuantParams],
-fp32_static_weights: bool = False,
+force_fp32: bool = False,
groups: int = 1,
) -> int:
"""
@@ -525,7 +523,7 @@ def get_serialized_buffer_index(
constant data. If used along with convert_to_nhwc, this
swap will happen before converting to nhwc.
quant_params: Quantization meta data for this tensor, None if it is not quantized
-fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+force_fp32: bool to indicate whether tensor is fp32 static weights
groups: groups for swap_in_out_for_weights

Returns:
@@ -554,7 +552,7 @@ def get_serialized_buffer_index(
# Quantize buffer if static data is indeed quantized
if quant_params is not None and not quant_params.is_dynamic:
const_val = quant_params.quantize_tensor(const_val).contiguous()
-elif const_val.dtype != torch.float16 or fp32_static_weights:
+elif const_val.dtype != torch.float16 or force_fp32:
# ensure that the const is fp32
const_val = const_val.to(dtype=torch.float32).contiguous()

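The renamed flag's effect is easiest to see in isolation. Below is a minimal standalone sketch of the decision get_serialized_dtype now makes for unquantized constants; XNNDatatype is stubbed with a plain enum here, and the real implementation also handles the quantized datatypes, so treat this as an illustration rather than the actual helper.

```python
import enum

import torch


class XNNDatatype(enum.Enum):
    # Stand-in for the serialized datatype enum; only the two values
    # relevant to this change are modeled.
    xnn_datatype_fp32 = "fp32"
    xnn_datatype_fp16 = "fp16"


def pick_serialized_dtype(node_dtype: torch.dtype, force_fp32: bool = False) -> XNNDatatype:
    # fp16 constants are now kept as fp16 unless the caller forces fp32
    # (e.g. the bias of a dynamically quantized op); everything else is
    # serialized as fp32, as before.
    if node_dtype == torch.float16 and not force_fp32:
        return XNNDatatype.xnn_datatype_fp16
    return XNNDatatype.xnn_datatype_fp32


assert pick_serialized_dtype(torch.float16) is XNNDatatype.xnn_datatype_fp16
assert pick_serialized_dtype(torch.float16, force_fp32=True) is XNNDatatype.xnn_datatype_fp32
assert pick_serialized_dtype(torch.float32) is XNNDatatype.xnn_datatype_fp32
```
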
6 changes: 3 additions & 3 deletions backends/xnnpack/operators/op_conv2d.py
@@ -82,7 +82,6 @@ def define_node(
weight_quant_params = QuantParams.from_weights(
kernel_node, self._exported_program
)
-fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16

if weight_quant_params is not None and weight_quant_params.per_channel:
if is_transpose:
Expand All @@ -102,8 +101,8 @@ def define_node(
convert_to_nhwc=True,
swap_in_out_for_weights=is_depthwise_conv or is_transpose,
quant_params=weight_quant_params,
-fp32_static_weights=fp32_static_weights,
groups=groups if is_transpose else 1,
+force_fp32=True,
)
kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]

@@ -127,13 +126,14 @@ def define_node(
bias_quant_params = QuantParams.from_bias(
bias_node, weight_quant_params, input_quant_params
)

self.define_tensor(
get_input_node(node, 2),
xnn_graph,
vals_to_ids,
convert_to_nhwc=False,
quant_params=bias_quant_params,
-fp32_static_weights=fp32_static_weights,
+force_fp32=True,
)
kwargs["bias_id"] = vals_to_ids[get_input_node(node, 2)]

9 changes: 7 additions & 2 deletions backends/xnnpack/operators/op_linear.py
@@ -59,7 +59,6 @@ def define_node(
xnn_graph,
vals_to_ids,
quant_params=weight_quant_params,
-fp32_static_weights=True,
)
filter_id = vals_to_ids[weight_node]

@@ -69,12 +68,18 @@ def define_node(
bias_quant_params = QuantParams.from_bias(
bias_node, weight_quant_params, input_quant_params
)
+# For dynamic quantization, there are no kernels with fp16 bias
+# So we need to force the fp16 bias to fp32
+force_fp32 = False
+if input_quant_params is not None and input_quant_params.is_dynamic:
+force_fp32 = True

self.define_tensor(
get_input_node(node, 2),
xnn_graph,
vals_to_ids,
quant_params=bias_quant_params,
-fp32_static_weights=True,
+force_fp32=force_fp32,
)
bias_id = vals_to_ids[bias_node]
else:
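The new comment in op_linear.py states the only remaining constraint: XNNPACK has no dynamically quantized linear kernel that accepts an fp16 bias. An illustrative sketch of the resulting bias policy (simplified, not the actual visitor code):

```python
import torch


def serialized_bias_dtype(bias: torch.Tensor, input_is_dynamically_quantized: bool) -> torch.dtype:
    # An fp16 bias now stays fp16, matching the fp16 weights. It is only
    # upcast when the linear's input is dynamically quantized, because the
    # dynamically quantized kernels require an fp32 bias.
    force_fp32 = input_is_dynamically_quantized
    if bias.dtype == torch.float16 and not force_fp32:
        return torch.float16
    return torch.float32


bias = torch.randn(16, dtype=torch.float16)
assert serialized_bias_dtype(bias, input_is_dynamically_quantized=False) == torch.float16
assert serialized_bias_dtype(bias, input_is_dynamically_quantized=True) == torch.float32
```
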
8 changes: 3 additions & 5 deletions backends/xnnpack/test/ops/test_linear.py
@@ -605,9 +605,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo

if legacy_partitioner:
tester.to_edge()
-tester.partition(
-Partition(DynamicallyQuantizedPartitioner)
-).dump_artifact()
+tester.partition(Partition(DynamicallyQuantizedPartitioner))
# should have [add]mm node
if uses_bias:
tester.check(
@@ -624,7 +622,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
else:
tester.to_edge_transform_and_lower(
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
-).dump_artifact()
+)
# should not have a delegate node
tester.check_not(
[
@@ -717,7 +715,7 @@ def test_fp16_linear(self):
num_batch_dims=num_batch_dims,
uses_bias=use_bias,
dtype=torch.float16,
-atol=5e-2, # TODO(T212995726): Investigate right atol for rand[n] inputs
+atol=5e-3, # TODO(T212995726): Investigate right atol for rand[n] inputs
)

def test_fp32_linear(self):
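For reference, this is roughly the path the fp16 linear test exercises, assuming the standard ExecuTorch XNNPACK lowering flow; module paths and call signatures may differ between ExecuTorch versions, so treat this as a sketch rather than code from the PR.

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear().to(torch.float16).eval()
example_inputs = (torch.randn(1, 32, dtype=torch.float16),)

# Export and lower to the XNNPACK delegate. With this change the fp16
# linear weights are serialized as fp16, roughly halving the constant
# data in the delegate payload compared to the old fp32 serialization.
exported = torch.export.export(model, example_inputs)
executorch_program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()
```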