Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions exir/passes/remove_noop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
continue

if node.target == torch.ops.aten.slice_copy.Tensor:
# Only do this check if all the dims are static.
if all(isinstance(dim, int) for dim in orig_tensor.size()):
if orig_tensor.shape == node.meta["val"].shape:
output_tensor = node.meta["val"]
# Only do this check if all dims are static on both sides.
# The output may contain unbacked SymInts (e.g. from
# data-dependent slicing with .item()), so we must check
# both tensors before comparing shapes.
if all(isinstance(dim, int) for dim in orig_tensor.size()) and all(
isinstance(dim, int) for dim in output_tensor.size()
):
if orig_tensor.shape == output_tensor.shape:
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
# Otherwise, removing only the op will suffice.
if node.args[0].target in _DEQUANT_OPS:
Expand Down
Loading