5 changes: 5 additions & 0 deletions exir/capture/_config.py
@@ -97,3 +97,8 @@ class ExecutorchBackendConfig:
     # If set to true, all trainable weights will be stored in a separate file,
     # external to the PTE file.
     external_mutable_weights: bool = False
+
+    # If set to true, all mutable buffers will have their fully qualified names
+    # serialized in the PTE file. This setting is ignored if mutable buffers are
+    # not memory planned, since their names must be serialized in that case anyway.
+    emit_mutable_buffer_names: bool = False
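For context, a minimal sketch of how a caller would opt in to the new flag. The `Counter` module is a hypothetical stand-in (not part of this PR), and the import paths assume the public `executorch.exir` entry points also used by the new test:

import torch
import torch.nn as nn
from torch.export import export

from executorch.exir import ExecutorchBackendConfig, to_edge


class Counter(nn.Module):
    # Hypothetical toy module with a single mutable buffer.
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x):
        self.count.add_(1)  # in-place mutation marks the buffer as mutable
        return x + self.count


ep = export(Counter(), (torch.randn(1),), strict=True)

# Opt in to serializing the buffer's fully qualified name ("count") into the PTE file.
program = to_edge(ep).to_executorch(
    config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
)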
2 changes: 2 additions & 0 deletions exir/emit/_emit_program.py
@@ -118,6 +118,7 @@ def emit_program(
     methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
     emit_stacktrace: bool = False,
     prim_getters: Optional[Dict[str, Any]] = None,
+    emit_mutable_buffer_names: bool = False,
 ) -> EmitterOutput:
     """
     Given an exported program, it returns the program in the format
@@ -163,6 +164,7 @@ def emit_program(
         operator_cache={},
         delegate_cache={},
         emit_stacktrace=emit_stacktrace,
+        emit_mutable_buffer_names=emit_mutable_buffer_names,
     )

     gm = _remove_non_user_outputs(exported_program)
10 changes: 9 additions & 1 deletion exir/emit/_emitter.py
@@ -149,6 +149,7 @@ class _EmitterState:
     # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
     delegate_cache: Dict[str, int]
     emit_stacktrace: bool
+    emit_mutable_buffer_names: bool

     spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)

@@ -1610,7 +1611,7 @@ def _find_fqn_for_placeholder(
         )
         return fqn, is_mutable_buffer

-    def placeholder(
+    def placeholder(  # noqa: C901
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _AbstractValue:
         """Emits the value within the placeholder node.
@@ -1639,6 +1640,13 @@ def placeholder(
             else:
                 spec.extra_tensor_info.fully_qualified_name = fqn
                 spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
+        if self.emitter_state.emit_mutable_buffer_names and is_mutable_buffer:
+            if spec.extra_tensor_info is None:
+                spec.extra_tensor_info = ExtraTensorInfo(
+                    fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
+                )
+            else:
+                spec.extra_tensor_info.fully_qualified_name = fqn

         # From the fqn find the corresponding tensor
         real_tensor = None
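To see the emitter change from the outside, the serialized names can be read back from the emitted program. Below is a small sketch (not part of this PR) that collects every fully qualified name, assuming the dataclasses in `executorch.exir.schema`, where each entry in `values` is an `EValue` wrapper whose `.val` holds the concrete `Tensor`:

from typing import Dict

from executorch.exir.schema import Program, Tensor


def collect_tensor_fqns(program: Program) -> Dict[str, int]:
    """Map each serialized fully qualified name to its value index in plan 0."""
    fqns: Dict[str, int] = {}
    for idx, evalue in enumerate(program.execution_plan[0].values):
        tensor = evalue.val
        if isinstance(tensor, Tensor) and tensor.extra_tensor_info is not None:
            fqn = tensor.extra_tensor_info.fully_qualified_name
            if fqn is not None:
                fqns[fqn] = idx
    return fqns

On the program built by the test below, this should return {"buffer": <index>} once emit_mutable_buffer_names is set.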
24 changes: 24 additions & 0 deletions exir/emit/test/test_emit.py
@@ -1819,3 +1819,27 @@ def forward(self, input, label):
         ]
         self.assertEqual(external_map["net.linear.weight"], 0)
         self.assertEqual(external_map["net.linear.bias"], 1)

+    def test_emit_mutable_buffer_names(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+                self.register_buffer("buffer", torch.zeros(1, 2))
+
+            def forward(self, x):
+                self.buffer.add_(1)
+                return self.linear(x) + self.buffer
+
+        net = Net()
+
+        ep = export(net, (torch.randn(1, 2),), strict=True)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        ep = ep.to_executorch(
+            config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
+        )
+        for val in ep.executorch_program.execution_plan[0].values:
+            if isinstance(val.val, Tensor) and val.val.extra_tensor_info:
+                self.assertEqual(val.val.extra_tensor_info.fully_qualified_name, "buffer")
1 change: 1 addition & 0 deletions exir/program/_program.py
@@ -1612,6 +1612,7 @@ def __init__(
             self._execution_programs,
             backend_config.emit_stacktrace,
             self._config_methods,
+            backend_config.emit_mutable_buffer_names,
         )

         # Serialize emitter output, ready to be written to a file.
# Serialize emitter output, ready to be written to a file.