Skip to content

Add proper subgraph-traversal for qonnx model_wrapper transform function #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4e75d75
Add brainsmith parsing to is_finn
Mar 18, 2025
4c97dea
allow newer versions of protobuf to enable compatilibty with onnxscript
Mar 27, 2025
718d84c
forward metadata
May 7, 2025
7bff09f
forward metadata from gemm_to_matmul step
May 12, 2025
9da4ed8
Merge pull request #181 from fastmachinelearning/feature/forward-meta…
jsmonson May 29, 2025
035264e
basic prototype idea for traversing subgraphs in FINN/QONNX/Brainsmith
May 29, 2025
a308bf2
Merge remote-tracking branch 'origin/custom/brainsmith' into feature/…
Jun 5, 2025
3117b33
update the approach to handle transforms that call transforms
Jun 6, 2025
e42f059
Merge branch 'main' of https://github.com/fastmachinelearning/qonnx i…
Jun 6, 2025
fab7c89
update model wrapper for traversal
Jun 6, 2025
e1c90e8
add subgraph traversal tests
Jun 6, 2025
56fcdf1
add NestTransform class and test to ensure that subgraph traversal ha…
Jun 6, 2025
e2ba0e4
bugfix for subgraph traversal in qonnx; update tests, fixed corner ca…
Jun 9, 2025
9264cac
update modelwrapper to store metadata in graph proto rather than mode…
Jun 10, 2025
97591dc
Merge remote-tracking branch 'origin/feature/switch-to-graph-metadata…
Jun 10, 2025
e452f74
fix tests now that metadata is stored in graphs
Jun 10, 2025
fdc3116
Revert "forward metadata from gemm_to_matmul step"
Jun 11, 2025
2946c2d
Revert "forward metadata"
Jun 11, 2025
650d7fb
Revert "allow newer versions of protobuf to enable compatilibty with …
Jun 11, 2025
457cf66
remove commented code in Transformation class
Jun 11, 2025
08f3e61
cleanup
Jun 11, 2025
d8eacda
update comments
Jun 11, 2025
a769c1d
revert changes from custom/brainsmith
Jun 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,13 @@ def analysis(self, analysis_fxn):
"""Runs given anaylsis_fxn on this model and return resulting dict."""
return analysis_fxn(self)

def transform(self, transformation, make_deepcopy=True, cleanup=True):
def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False):
"""Applies given Transformation repeatedly until no more changes can be made
and returns a transformed ModelWrapper instance.

- make_deepcopy : operates on a new (deep)copy of model.
- cleanup : execute cleanup transformations before returning
- apply_to_subgraphs : if True, transformation is applied to all subgraphs of the model
"""
transformed_model = self
if make_deepcopy:
Expand All @@ -145,6 +146,31 @@ def transform(self, transformation, make_deepcopy=True, cleanup=True):
(transformed_model, model_was_changed) = transformation.apply(transformed_model)
if cleanup:
transformed_model.cleanup()

if apply_to_subgraphs:
for node in transformed_model.model.graph.node:
transformed_subgraph_attrs = []
for idx, attr in enumerate(node.attribute):
if attr.type == onnx.AttributeProto.GRAPH:
# this is a subgraph, add it to the list
subgraph = ModelWrapper(util.qonnx_make_model(attr.g))
# extract all model metadata from loop model and apply to body
for metadata in transformed_model.model.metadata_props:
subgraph.set_metadata_prop(metadata.key, metadata.value)
# apply the transformation to the subgraph
subgraph = subgraph.transform(transformation, make_deepcopy, cleanup, apply_to_subgraphs)
# copy model metadata from the subgraph to the parent model
for metadata in subgraph.model.metadata_props:
transformed_model.set_metadata_prop(metadata.key, metadata.value)
# update the new subgraph in the attrubute
transformed_subgraph_attrs.append((idx, onnx.helper.make_attribute(attr.name, subgraph.model.graph)))
# replace the attributes in the node with the transformed subgraph attributes
for idx, new_attr in transformed_subgraph_attrs:
# remove the old attribute
node.attribute.pop(idx)
# add the new attribute
node.attribute.insert(idx, new_attr)

return transformed_model

def cleanup(self):
Expand Down Expand Up @@ -566,20 +592,20 @@ def get_tensor_fanout(self, tensor_name):
def get_metadata_prop(self, key):
"""Returns the value associated with metadata_prop with given key,
or None otherwise."""
metadata_prop = util.get_by_name(self.model.metadata_props, key, "key")
metadata_prop = util.get_by_name(self.model.graph.metadata_props, key, "key")
if metadata_prop is None:
return None
else:
return metadata_prop.value

def set_metadata_prop(self, key, value):
"""Sets metadata property with given key to the given value."""
metadata_prop = util.get_by_name(self.model.metadata_props, key, "key")
metadata_prop = util.get_by_name(self.model.graph.metadata_props, key, "key")
if metadata_prop is None:
metadata_prop = onnx.StringStringEntryProto()
metadata_prop.key = key
metadata_prop.value = value
self.model.metadata_props.append(metadata_prop)
self.model.graph.metadata_props.append(metadata_prop)
else:
metadata_prop.value = value

Expand Down Expand Up @@ -695,3 +721,22 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)

def get_subgraph_modelwrappers(self):
"""Find all subgraphs in the model by looking for graphs in node attributes.
Return them as a list of ModelWrappers in breadth-first search order."""

nodes_to_search = []
nodes_to_search.extend(self.graph.node)
subgraphs = []
while len(nodes_to_search) > 0:
node = nodes_to_search.pop(0)
for attr in node.attribute:
if attr.type == onnx.AttributeProto.GRAPH:
# this is a subgraph, add it to the list
subgraph = ModelWrapper(util.qonnx_make_model(attr.g))
subgraphs.append(subgraph)
# add the subgraph nodes to the search list
nodes_to_search.extend(subgraph.graph.node)

return subgraphs
21 changes: 21 additions & 0 deletions src/qonnx/transformation/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,27 @@ def apply(self, model):
for node_idx, n in enumerate(node_list):
node_pred = model.find_direct_predecessors(n)
if node_pred is None:
# if connected only to input and output it doesn't matter where it goes
# but should not be removed from the grpah
if len(n.input) == 0 or len(n.output) == 0:
continue

connected_to_graph_inputs_only = True
for inp in n.input:
tensor_names = [vi.name for vi in model.graph.input]
if inp not in tensor_names:
connected_to_graph_inputs_only = False
break
connected_to_graph_outputs_only = True
for outp in n.output:
tensor_names = [vi.name for vi in model.graph.output]
if outp not in tensor_names:
connected_to_graph_outputs_only = False
break
if connected_to_graph_inputs_only and connected_to_graph_outputs_only:
graph_dependencies[node_idx] = set()
continue

# Will also eliminate nodes that are floating around for some reason
continue

Expand Down
197 changes: 197 additions & 0 deletions tests/core/test_subgraph_traversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import pytest
from collections import Counter

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.base import Transformation

from qonnx.util.basic import qonnx_make_model, get_by_name
import onnx
from onnx import helper

# Helper to recursively build a graph with subgraphs attached to nodes
def make_graph(tree):
"""
Recursively build a ModelWrapper tree from a nested tuple/list structure.
Each graph will have one node per subgraph, with the subgraph attached as a node attribute.
Example input:
("top", [("sub1", []), ("sub2", [("sub2_1", [])])])
Returns the top-level ModelWrapper.
"""
name, subtrees = tree
# Create subgraphs recursively
subgraph_nodes = []
inputs = []
outputs = []
for subtree in subtrees:
subgraph = make_graph(subtree)
sg_name_in = f"{subgraph.name}_in"
sg_name_out = f"{subgraph.name}_out"
inputs.append(onnx.helper.make_tensor_value_info(sg_name_in, onnx.TensorProto.FLOAT, [4, 4]))
outputs.append(onnx.helper.make_tensor_value_info(sg_name_out, onnx.TensorProto.FLOAT, [4, 4]))
# Attach subgraph as attribute to node
node = helper.make_node(
op_type="SubgraphNode", # dummy op_type
inputs=[sg_name_in],
outputs=[sg_name_out],
name=f"{subgraph.name}_node",
)
# ONNX expects subgraphs as AttributeProto, so we set it below
attr = onnx.helper.make_attribute("body", subgraph)
node.attribute.append(attr)
subgraph_nodes.append(node)
# Create the graph for this level
graph = helper.make_graph(
nodes=subgraph_nodes,
name=name,
inputs=inputs,
outputs=outputs,
)

return graph

def make_subgraph_model(tree):
"""
Build a ModelWrapper with a graph structure based on the provided tree.
The tree is a nested tuple/list structure where each node can have subgraphs.
"""
return ModelWrapper(qonnx_make_model(make_graph(tree)))


class DummyTransform(Transformation):
def __init__(self):
self.visited = list()

def apply(self, model_wrapper):
# get the name of the graph being transformed
graph_name = model_wrapper.model.graph.name
# set a metadata property to test whether metadata is preserved
model_wrapper.set_metadata_prop(graph_name, "visited")
# add a dummy node to the graph to simulate a transformation
# to see if the subgraph transformation is presered

dummy_name_in = f"{graph_name}_dummy_in"
dummy_name_out = f"{graph_name}_dummy_out"
model_wrapper.model.graph.input.append(helper.make_tensor_value_info(dummy_name_in, onnx.TensorProto.FLOAT, [4, 4]))
model_wrapper.model.graph.output.append(helper.make_tensor_value_info(dummy_name_out, onnx.TensorProto.FLOAT, [4, 4]))
model_wrapper.model.graph.node.append(
helper.make_node(
"DummyNode", # dummy op_type
inputs=[dummy_name_in],
outputs=[dummy_name_out],
name=f"{graph_name}_dummy_node",
)
)

# collect the name of the graph being transformed to check how many times each graph was visited
self.visited.append(graph_name)
#import pdb; pdb.set_trace()
return model_wrapper, False

class NestedTransform(Transformation):
def __init__(self):
self.dummy_transform = DummyTransform()
def apply(self, model_wrapper):
return model_wrapper.transform(self.dummy_transform), False

def get_subgraph_names(tree):
"""
Recursively collect the names of all subgraphs in the tree structure.
"""
names = set()

def traverse(tree):
name = tree[0]
subgraphs = tree[1]
names.add(name)
for subgraph in subgraphs:
traverse(subgraph)

traverse(tree)
return names


def check_all_visted_once(tree, transform):
"""
Check that all subgraphs in the tree structure were visited exactly once.
"""
visited = transform.visited
expected = get_subgraph_names(tree)
assert Counter(visited) == Counter(expected), f"Visited: {visited}, Expected: {expected}"

def check_all_subgraphs_transformed(graph):
"""
Check that all subgraphs in the tree structure have been transformed.
"""

# look for the optype "DummyNode" in the model graph
dummynode_found = False
for node in graph.node:
if node.op_type == "DummyNode":
dummynode_found = True
break
if not dummynode_found:
raise AssertionError(f"DummyNode not found in the transformed model graph {graph.name}")

# check that metadata is set for all subgraphs
def get_metadata_props(graph, key):
metadata_prop = get_by_name(graph.metadata_props, key, "key")
if metadata_prop is None:
return None
else:
return metadata_prop.value

assert(get_metadata_props(graph, graph.name) == "visited"), f"Metadata for {graph.name} not set correctly"

# recursively check all subgraphs
for node in graph.node:
for attr in node.attribute:
if attr.type == onnx.AttributeProto.GRAPH:
check_all_subgraphs_transformed(attr.g)

@pytest.mark.parametrize("cleanup", [False, True])
@pytest.mark.parametrize("make_deepcopy", [False, True])
@pytest.mark.parametrize("model, apply_to_subgraphs",
[(make_subgraph_model(("top", [])), True),
(make_subgraph_model(("top", [])), False),
(make_subgraph_model(("top", [("sub1", [])])), False)])
def test_no_traversal(model, cleanup, make_deepcopy, apply_to_subgraphs):
# Check that the top-level model is transformed exactly once when there are no subgraphs.
# Check that the top-level model is transformed exactly once when there are subgraphs, but apply_to_subgraphs is False.
# This should always be done correctly regardless of cleanup and make_deepcopy.

transform = DummyTransform()
t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs)

assert transform.visited == ["top"]
assert t_model.get_metadata_prop("top") == "visited"


@pytest.mark.parametrize("cleanup", [False, True])
@pytest.mark.parametrize("make_deepcopy", [False, True])
@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]),
("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])])
def test_traversal(tree, cleanup, make_deepcopy):
# Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True.
# This should always be done correctly regardless of cleanup and make_deepcopy.
print(f"Testing tree: {tree}, cleanup: {cleanup}, make_deepcopy: {make_deepcopy}")
model = make_subgraph_model(tree)
transform = DummyTransform()
t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs=True)

check_all_visted_once(tree, transform)
check_all_subgraphs_transformed(t_model.model.graph)


@pytest.mark.parametrize("cleanup", [False, True])
@pytest.mark.parametrize("make_deepcopy", [False, True])
@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]),
("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])])
def test_traversal_nested(tree, cleanup, make_deepcopy):
# Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True.
# This should always be done correctly regardless of cleanup and make_deepcopy.
model = make_subgraph_model(tree)
transform = NestedTransform()
t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs=True)

check_all_visted_once(tree, transform.dummy_transform)
check_all_subgraphs_transformed(t_model.model.graph)
Loading