Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,14 @@ def duplicate_constant_node(
for arg in users[ith].args
]
)
new_kwargs = dict(
{
(
key,
(
value
if value != constant_or_attribute_node
else copied_constant_or_attribute_node
),
)
for key, value in users[ith].kwargs
}
)
new_kwargs = {
key: (
value
if value != constant_or_attribute_node
else copied_constant_or_attribute_node
)
for key, value in users[ith].kwargs.items()
}
users[ith].args = new_args
users[ith].kwargs = new_kwargs
if old_input_spec.kind == InputKind.CONSTANT_TENSOR:
Expand Down
40 changes: 40 additions & 0 deletions exir/backend/test/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,43 @@ def forward(self, x):
FileCheck().check("b_const_copy_0").run(
edge.exported_program().graph_module.code
)

def test_duplicate_constant_node_with_kwargs_users(self) -> None:
"""
Test that duplicate_constant_node correctly handles nodes where users
reference the constant via kwargs (not just args).
"""

class ModelWithBuffer(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("const_buffer", torch.tensor([1.0, 2.0, 3.0]))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.const_buffer + x * self.const_buffer

model = export(ModelWithBuffer(), (torch.randn(3),), strict=True).module()
edge = exir.to_edge(torch.export.export(model, (torch.randn(3),), strict=True))

# Find the buffer node
buffer_node = None
for node in edge.exported_program().graph.nodes:
if node.op == "placeholder" and is_buffer(edge.exported_program(), node):
buffer_node = node
break

# Move buffer reference from args to kwargs for one user
users = list(buffer_node.users.keys())
user = users[1]
user.args = tuple(a for a in user.args if a is not buffer_node)
user.kwargs = {"other": buffer_node}

# Patch validation since we modified the graph
edge.exported_program()._validate = lambda: None

copied_nodes = duplicate_constant_node(
edge.exported_program(), buffer_node.name
)

self.assertEqual(len(copied_nodes), 1)
self.assertNotEqual(user.kwargs["other"], buffer_node)
Loading