Skip to content

[TorchOnnx] Fix ScatterND last unflatten step#4586

Open
juanigp wants to merge 3 commits into
llvm:mainfrom
juanigp:scatternd-fix
Open

[TorchOnnx] Fix ScatterND last unflatten step#4586
juanigp wants to merge 3 commits into
llvm:mainfrom
juanigp:scatternd-fix

Conversation

@juanigp

@juanigp juanigp commented May 28, 2026

Copy link
Copy Markdown

The issue was seen when trying to lower this bevformer-tiny checkpoint.

The ScatterND lowering flattens the first X dimensions of the data tensor for aten.scatter.src, then tries to restore the original ONNX result shape, which can produce invalid operations.

For bevformer-tiny, there is a ScatterND operation:

%667 = onnx.Reshape ... -> [2500,1,256]
%668 = onnx.Gather(%667, axis=1) -> [2500,256]
%669 = onnx.Reshape(%668, ...) -> [2500,1,256]
%670 = onnx.ScatterND(%arg3, %341, %669) -> [2500,1,256]

converted to:

%670 = torch.operator "onnx.ScatterND"(%arg3, %341, %669)
  : (!torch.vtensor<[2500,1,256],f32>,
     !torch.vtensor<[2500,1,2],si64>,
     !torch.vtensor<[2500,1,256],f32>)
 -> !torch.vtensor<[2500,1,256],f32>

and its lowering produces:

%939 = torch.prim.ListConstruct %915, %916, %917 // 2500, 1, 256
%940 = torch.aten.unflatten.int %938, %int0_654, %939
  : !torch.vtensor<[2500,256],f32>, !torch.int, !torch.list<int>
 -> !torch.vtensor<[2500,1,256],f32>

The latter IR is trying to reshape a [2500, 256] tensor into [2500, 1, 256] by replacing dim 0 (2500) by [2500, 1, 256] which is wrong: 2500 at dim 0 should be replaced by [2500, 1] instead (and leave the trailing 256 unaffected).

Fix:
This PR re-writes the unflattening step of ScatterND lowering to properly unflatten the flattened dim 0 into its original shape, and not the original shape of the whole data tensor. Replacing:

        // step 11. Unflatten the collapsed data dims of scatter result.
        if (indicesLastDim == 1) {
          rewriter.replaceOp(binder.op, scatter);
          return success();
        }
        Value unflattenSizeList = Torch::PrimListConstructOp::create(
            rewriter, loc, intListTy, dataDims);
        rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
            binder.op, resultType, scatter, constZero, unflattenSizeList);
        return success();
        // step 11. Unflatten the collapsed indexed prefix of the scatter
        // result to restore the original data shape. Step 9 flattened only
        // dims [0 .. indicesLastDim-1] of data into a single leading axis;
        // the trailing dims [indicesLastDim .. dataRank-1] were preserved
        // through scatter and are already in place. The inverse is therefore
        // an unflatten of axis 0 alone, splitting it back into
        // dataDims[0:indicesLastDim].
        if (indicesLastDim == 1) {
          rewriter.replaceOp(binder.op, scatter);
          return success();
        }
        SmallVector<Value> collapsedDataDims(dataDims.begin(),
                                             dataDims.begin() + indicesLastDim);
        Value unflattenSizeList = Torch::PrimListConstructOp::create(
            rewriter, loc, intListTy, collapsedDataDims);
        rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
            binder.op, resultType, scatter, constZero, unflattenSizeList);
        return success();
      });

@juanigp juanigp changed the title scatternd fix [TorchOnnx] Fix ScatterND last unflatten step May 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant