Extend constant prop pass to work with int/float/etc scalars and fix input specs. #2950
2 changes: 2 additions & 0 deletions exir/passes/TARGETS
@@ -92,6 +92,8 @@ python_library(
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
    ],
)

342 changes: 259 additions & 83 deletions exir/passes/constant_prop_pass.py
@@ -4,58 +4,145 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import cast, Mapping, Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)
from torch._guards import detect_fake_mode
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.utils import _pytree as pytree


# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
# Propagating aten.full can significantly increase compiled model size.
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}

_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)
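# Note: the non-tensor entries above (int/float/bool/str plus device/dtype/layout)
# are what extend this pass to scalar arguments, per the PR title; in the old
# code, any non-placeholder argument made the whole call non-constant.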


def is_const(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
) -> bool:
    if isinstance(arg, (tuple, list)):
        return all(is_const(x, exported_program, const_node_to_tensor) for x in arg)
    elif isinstance(arg, dict):
        return all(
            is_const(x, exported_program, const_node_to_tensor) for x in arg.values()
        )
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return True
    elif not isinstance(arg, torch.fx.Node):
        return False
    elif arg in const_node_to_tensor:
        return True
    return False
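# For illustration (hypothetical nodes): a tuple argument like
# (param_node, 2, "dim") is constant once `param_node` is tracked in
# `const_node_to_tensor`, since tuples/lists recurse and Python scalars
# are covered by _PRIMITIVE_TYPES.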


def get_data(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
):
    if isinstance(arg, (tuple, list)):
        return type(arg)(
            get_data(x, exported_program, const_node_to_tensor) for x in arg
        )
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return arg
    elif arg in const_node_to_tensor:
        return const_node_to_tensor[arg]
    return None


def get_constant_placeholder_dict(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Returns a dictionary of placeholder node -> constant tensor.
    """
    const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue

        if is_param(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_param(exported_program, node)
            )
        elif is_buffer(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_buffer(exported_program, node)
            )
        elif is_lifted_tensor_constant(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_lifted_tensor_constant(exported_program, node)
            )
    return const_node_to_tensor


def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Propagates constants and returns a dictionary of node -> constant tensors.
    """
    # Initialize the dict with all constant placeholders.
    const_node_to_tensor = get_constant_placeholder_dict(exported_program)

    # Default set of targets to skip, plus any custom targets.
    all_skip_targets: set[EdgeOpOverload] = set()
    all_skip_targets.update(_DEFAULT_SKIP_TARGETS)
    if custom_skip_targets is not None:
        all_skip_targets.update(custom_skip_targets)

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in all_skip_targets:
            continue

        if not is_const(
            node.args,
            exported_program,
            const_node_to_tensor,
        ):
            continue

        args_data, kwargs_data = pytree.tree_map(
            lambda x: get_data(x, exported_program, const_node_to_tensor),
            (node.args, node.kwargs),
        )

        # Execute `node.target` and create a new propagated constant tensor.
        prop_constant_tensor = node.target(*args_data, **kwargs_data)
        const_node_to_tensor[node] = prop_constant_tensor

    return const_node_to_tensor
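# For illustration (hypothetical graph): given `y = param.mul(2).add(1)`, both
# the mul and add call_function nodes fold here, since their inputs are a lifted
# parameter plus Python ints; both end up mapped to concrete tensors above.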


def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node:
    """Returns the first user input node in the graph."""
    first_user_input = None
    for node in exported_program.graph.nodes:
        if (
            node.op == "placeholder"
            and node.name in exported_program.graph_signature.user_inputs
        ):
            first_user_input = node
            break
    return first_user_input


def replace_with_constant_node(
    node: torch.fx.Node,
    prop_constant_tensor: torch.Tensor,
    first_user_input: torch.fx.Node,
    fake_mode,
    exported_program: ExportedProgram,
) -> tuple[torch.fx.Node, str]:
    # Add `prop_constant_tensor` to program.constants.
    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
    exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor

    # Insert a new placeholder node for the propagated constant tensor.
    with exported_program.graph.inserting_before(first_user_input):
        const_placeholder_node = exported_program.graph.placeholder(
            prop_constant_tensor_fqn
        )

    # Update the metadata of the new placeholder (buffer) node.
    for k, v in node.meta.items():
        const_placeholder_node.meta[k] = v
    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
        prop_constant_tensor, static_shapes=True
    )
    const_placeholder_node.meta["val"].constant = prop_constant_tensor

    # Replace the original node with the new constant node.
    node.replace_all_uses_with(const_placeholder_node)
    exported_program.graph.erase_node(node)

    return const_placeholder_node, prop_constant_tensor_fqn


def get_fake_mode(exported_program: ExportedProgram):
    fake_mode = detect_fake_mode(
        tuple(
            node.meta["val"]
            for node in exported_program.graph.nodes
            if node.op == "placeholder"
        )
    )
    assert fake_mode is not None
    return fake_mode


def erase_constant_node(
    exported_program: ExportedProgram,
    node: torch.fx.Node,
) -> None:
    # Remove the corresponding tensor from the param/constants dict.
    signature = exported_program.graph_signature
    if name := signature.inputs_to_parameters.pop(node.name, None):
        exported_program.state_dict.pop(name, None)
    elif name := signature.inputs_to_lifted_tensor_constants.pop(node.name, None):
        exported_program.constants.pop(name, None)
    elif name := signature.inputs_to_buffers.pop(node.name, None):
        exported_program.constants.pop(name, None)
        exported_program.state_dict.pop(name, None)

    # Remove from the graph.
    exported_program.graph.erase_node(node)


def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """
    Creates constant nodes for all entries in `const_node_to_tensor` and
    returns a node.name -> InputSpec dict.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}

    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Iterate over nodes in reverse (topological) order, so that a node's users
    # are handled before the node itself and erasing a fully consumed constant
    # never leaves live users behind.
    for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
        if all(x in const_node_to_tensor for x in node.users):
            # All users of this constant node are also constant, so we don't
            # need to create a new constant node.
            erase_constant_node(exported_program, node)
            continue

        if node.op == "placeholder":
            continue

        const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
            node, prop_constant_tensor, first_user_input, fake_mode, exported_program
        )

        # Create an input spec for the lifted constant.
        name_to_spec_dict[const_placeholder_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=const_placeholder_node.name),
            target=prop_constant_tensor_fqn,
            persistent=True,
        )
    return name_to_spec_dict


def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """
    Performs constant propagation on an ExportedProgram with lifted parameters.
    Because parameters are lifted, they appear in the graph as `placeholder`
    nodes rather than `get_attr` nodes.

    Args:
        exported_program: The ExportedProgram to perform constant propagation on.
        custom_skip_targets: Optional set of EdgeOpOverload targets to skip
            during constant propagation.

    Returns:
        The modified ExportedProgram with constant propagation applied.
    """
    if (
        len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
        == 0
    ):
        return exported_program

    has_control_flow = [
        node
        for node in exported_program.graph.nodes
        if node.target == torch.ops.higher_order.cond
    ]
    if len(has_control_flow) > 0:
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Get the old input specs.
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }
    # Add the new constants to the input specs dict.
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Generate the new input specs.
    new_input_specs = []
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue
        new_input_specs.append(name_to_spec_dict[node.name])
    exported_program.graph_signature.input_specs = new_input_specs

    # Clean up the graph.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
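
A minimal usage sketch of the updated pass, assuming ExecuTorch's standard `torch.export` + `to_edge` flow; the module, shapes, and the extra skip target below are illustrative, not part of this PR:

import torch
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.constant_prop_pass import constant_prop_pass


class AddPrecomputed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(4, 4))

    def forward(self, x):
        # `self.w * 2.0` depends only on a lifted parameter and a float scalar,
        # so the pass can fold it into a single precomputed constant tensor.
        return x + self.w * 2.0


ep = torch.export.export(AddPrecomputed(), (torch.randn(4, 4),))
edge = to_edge(ep)

# Fold constants in place; edge.aten.full.default is skipped by default.
prop_ep = constant_prop_pass(edge.exported_program())

# Optionally skip additional edge ops via `custom_skip_targets`:
# constant_prop_pass(
#     edge.exported_program(),
#     custom_skip_targets={exir_ops.edge.aten.mul.Tensor},
# )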