[ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity #8934

Merged 1 commit on Mar 4, 2025
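The rename itself is a keyword-argument change on XnnpackPartitioner, with the config plumbing and test updates shown in the diff below. A minimal caller-side sketch of the migration, assuming the import path used by the updated tests (illustrative only, not part of this PR):

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Old spelling (before this PR):
# partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=True)

# New spelling (this PR); the partitioning behavior it enables is unchanged.
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
```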
38 changes: 25 additions & 13 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -97,9 +97,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
def _overwrite_precision(self, node: torch.fx.Node):
precision = self._detect_precision(node)
if precision not in self.enabled_precision_types:
# detected precision is not enabled, lets try to partition it as fp32
# detected precision is not enabled, try to partition it as fp32
if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
# if only fp32 is enabled, then we can still partition fp32 gemms
# when only fp32 is enabled, then we can still partition fp32 gemms
# even with in a quantized graph
if precision in [
ConfigPrecisionType.STATIC_QUANT,
@@ -108,6 +108,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
precision = ConfigPrecisionType.FP32
logging.info(f"Overwriting precision, partitioning {node} as FP32")
return True, precision

return False, precision

def get_deps(
@@ -210,8 +211,11 @@ def _get_bias_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
) -> Tuple[bool, List[torch.fx.Node]]:
gemm_deps = []
if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
# if force force_fp32_dynamic_linear is enabled, then we
if (
precision == ConfigPrecisionType.FP32
and self.force_non_static_weights_for_f32_linear
):
# if force_non_static_weights_for_f32_linear is enabled, then we
# do not partition the weight node
return (True, gemm_deps)

@@ -287,8 +291,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
def _get_weight_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
) -> Tuple[bool, List[torch.fx.Node]]:
if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
# if force fp32_dynamic_linear is enabled, then we
if (
precision == ConfigPrecisionType.FP32
and self.force_non_static_weights_for_f32_linear
):
# if force_non_static_weights_for_f32_linear is enabled, then we
# do not partition the weight node
return (True, [])

@@ -394,9 +401,11 @@ def __init__(self, **kwargs):
def _get_weight_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
) -> Tuple[bool, List[torch.fx.Node]]:
# TODO(maxren, T210537195):
if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
# if force fp32_dynamic_linear is on and we detected this as fp32, then we
if (
precision == ConfigPrecisionType.FP32
and self.force_non_static_weights_for_f32_linear
):
# if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
# do not partition the weight node
return (True, [])

@@ -482,11 +491,11 @@ def find_partition_args(input_node):
node.args = old_args
node.users = old_users

# When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
# When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
# Else we want to be greedy.
ret_deps = (
list(set(deps) & set(src_partition.nodes))
if self.force_fp32_dynamic_linear
if self.force_non_static_weights_for_f32_linear
else list(set(deps) | set(src_partition.nodes))
)

@@ -512,8 +521,11 @@ def __init__(self, **kwargs):
def _get_weight_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
) -> Tuple[bool, List[torch.fx.Node]]:
if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
# if force fp32_dynamic_linear is on and we detected this as fp32, then we
if (
precision == ConfigPrecisionType.FP32
and self.force_non_static_weights_for_f32_linear
):
# if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
# do not partition the weight node
return (True, [])

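Taken together, the gemm_configs.py hunks above gate the same behavior under the new name: when the flag is set and the node was detected as FP32, the weight dependencies come back empty, so the weight stays outside the partition and is supplied to the linear at runtime. A minimal sketch of that gating pattern, where every name other than the flag itself is hypothetical:

```python
from typing import List, Tuple

import torch


def _weight_deps_sketch(
    node: torch.fx.Node,
    is_fp32: bool,
    force_non_static_weights_for_f32_linear: bool,
) -> Tuple[bool, List[torch.fx.Node]]:
    # Illustrative only: mirrors the gating pattern repeated in the hunks above.
    if is_fp32 and force_non_static_weights_for_f32_linear:
        # Report no weight dependency, so the weight is left out of the
        # partition and fed to the delegate at runtime instead of being
        # captured as a static constant.
        return (True, [])
    # Otherwise the weight node is partitioned along with the gemm.
    weight_inputs = node.all_input_nodes[1:2]
    return (True, list(weight_inputs))
```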
4 changes: 3 additions & 1 deletion backends/xnnpack/partition/config/xnnpack_config.py
@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
super().__init__()
self.enabled_precision_types = self.supported_precision_types()
# Flag used in GEMMConfig()
self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
self.force_non_static_weights_for_f32_linear = kwargs.get(
"force_non_static_weights_for_f32_linear", False
)

def get_partition(
self, node: torch.fx.Node, ep: ExportedProgram
4 changes: 2 additions & 2 deletions backends/xnnpack/test/ops/test_linear.py
@@ -874,7 +874,7 @@ def test_linear_qd8_as_fp32(self):
},
)

def test_linear_fp32_with_force_as_mm(self):
def test_linear_with_force_non_static_weights_for_f32_linear(self):
def check_signature(
signature: ExportGraphSignature,
force_flag: bool,
@@ -907,7 +907,7 @@ def check_signature(
inputs = module.get_inputs()
tester = Tester(module, inputs).export()
partitioner = XnnpackPartitioner(
force_fp32_dynamic_linear=force_flag
force_non_static_weights_for_f32_linear=force_flag
)
if legacy_mode:
tester.to_edge()
8 changes: 5 additions & 3 deletions backends/xnnpack/test/ops/test_lstm.py
@@ -43,18 +43,20 @@ def test_fp32_lstm(self):
.run_method_and_compare_outputs()
)

def test_fp32_lstm_force_dynamic_linear(self):
def test_lstm_with_force_non_static_weights_for_f32_linear(self):
(
Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
.export()
.to_edge_transform_and_lower(
ToEdgeTransformAndLower(
partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
partitioners=[
XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
]
)
)
.check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
# Weights are supplied as input to linears
# Biases are not owned by delegates when force_fp32_dynamic_linear is set
# Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
.check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
.to_executorch()
.serialize()