
[TorchToTosa] Add lowering for AtenSortOp #4581

Open
alosim01 wants to merge 5 commits into llvm:main from alosim01:add-aten-sort-op


@alosim01


Summary

Adds TorchToTosa lowering support for torch.aten.sort.

The lowering supports statically-shaped ranked tensors with floating-point element types,
constant dim, and constant descending. It emits repeated tosa.argmax/tosa.gather
selections, masks selected elements with a sentinel, and returns both sorted values and
indices.

This also handles the common decomposed topk pattern where sort results are immediately
prefix-sliced, lowering only the requested prefix instead of sorting the full dimension.
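The selection scheme and the prefix shortcut described above can be sketched in plain Python. This is a minimal illustration, not the PR's C++ lowering: the helper names `argmax` and `sort_by_repeated_argmax` are hypothetical, negating the input for ascending order is an assumption about how one might reuse argmax, and only finite inputs are modelled (NaN and infinities would collide with the sentinel used here).

```python
import math

def argmax(values):
    """Index of the first maximum (ties go to the lowest index)."""
    best = 0
    for i in range(1, len(values)):
        if values[i] > values[best]:
            best = i
    return best

def sort_by_repeated_argmax(values, descending=True, k=None):
    """Sort a 1-D list of finite floats by selecting one element per step.

    Hypothetical helper for illustration. Passing k < len(values) selects
    only the first k results, mirroring the top-k prefix case.
    """
    n = len(values)
    k = n if k is None else k
    # Assumption: ascending order is obtained by negating the working copy,
    # so that argmax picks the smallest original value.
    work = [v if descending else -v for v in values]
    sentinel = -math.inf  # masks elements that were already selected
    out_values, out_indices = [], []
    for _ in range(k):
        idx = argmax(work)
        out_values.append(values[idx])  # "gather" the original element
        out_indices.append(idx)
        work[idx] = sentinel
    return out_values, out_indices

full_desc = sort_by_repeated_argmax([3.0, 1.0, 2.0])
full_asc = sort_by_repeated_argmax([3.0, 1.0, 2.0], descending=False)
top2 = sort_by_repeated_argmax([3.0, 1.0, 2.0, 5.0], k=2)
```

Each selection step costs one argmax pass over the dimension, so a prefix of length k needs k passes instead of one per element.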

Details

  • Lowers full aten.sort for ascending and descending order.
  • Supports sorting along non-last dimensions by transposing to the selection dimension and
    transposing back.
  • Produces i64 indices when required by the Torch result type.
  • Handles rank-zero tensors by returning the input value and index 0.
  • Tightens AtenSortOp folding so invalid static dims do not fold incorrectly.
  • Adds regression tests for full sort, decomposed top-k prefix slicing, rank-zero sort, and
    invalid-dim no-fold behavior.
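The transpose-sort-transpose approach for non-last dimensions can be illustrated on a 2-D list. This is a sketch under stated assumptions: `transpose`, `sort_rows`, and `sort_along_dim` are hypothetical names, only values (not indices) are returned, and Python's built-in `sorted` stands in for the selection loop.

```python
def transpose(matrix):
    """Swap the two axes of a rectangular list of lists."""
    return [list(col) for col in zip(*matrix)]

def sort_rows(matrix, descending=False):
    """Sort each row independently (the last-dimension case)."""
    return [sorted(row, reverse=descending) for row in matrix]

def sort_along_dim(matrix, dim, descending=False):
    """Sort a 2-D list along `dim` (0 or 1; negative values count from the end)."""
    rank = 2
    if dim < 0:
        dim += rank
    if dim < 0 or dim >= rank:
        raise ValueError("dim out of range")
    if dim == rank - 1:
        return sort_rows(matrix, descending)
    # Move the sort axis last, sort, then move it back.
    return transpose(sort_rows(transpose(matrix), descending))

by_columns = sort_along_dim([[3, 1], [2, 4]], dim=0)
by_rows = sort_along_dim([[3, 1], [2, 4]], dim=-1)
```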

alosim01 force-pushed the add-aten-sort-op branch from 5fb0d05 to 3a5ccc4 on May 27, 2026 09:23

Lallapallooza left a comment

Member


Thanks for the patch; a few suggestions.

Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp
Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp
Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp
Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp
Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp Outdated
Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp
Comment thread test/Conversion/TorchToTosa/basic.mlir Outdated
Comment thread projects/pt1/e2e_testing/xfail_sets.py
alosim01 force-pushed the add-aten-sort-op branch 2 times, most recently from bca17c9 to 5c4b7ba on June 16, 2026 09:04

Lallapallooza left a comment


Thanks for the update; a few questions.

@@ -281,15 +281,46 @@ static FailureOr<Value> createIntOrFloatCompareOp(PatternRewriter &rewriter,
}

if (isa<mlir::FloatType>(elementType)) {


Can we split the TMTensor comparator change out of this TOSA sort PR?

Comment on lines +1603 to +1631
Type elementTy = selfTy.getElementType();
if (!elementTy.isF32() && !elementTy.isF16() && !elementTy.isBF16())
  return rewriter.notifyMatchFailure(
      op, "only f32, f16, and bf16 element types are supported");

bool descending;
if (!matchPattern(op.getDescending(), m_TorchConstantBool(&descending)))
  return rewriter.notifyMatchFailure(
      op, "unimplemented: only constant descending value is supported");

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
  return rewriter.notifyMatchFailure(
      op, "unimplemented: only constant dim value is supported");

int64_t rank = selfTy.getRank();
if (rank == 0) {
  if (dim != 0 && dim != -1)
    return rewriter.notifyMatchFailure(op, "scalar sort dim is invalid");

  auto indicesTy = cast<RankedTensorType>(
      getTypeConverter()->convertType(op.getResult(1).getType()));
  Value indices = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
  if (indicesTy.getElementType().isInteger(64))
    indices = tosa::CastOp::create(rewriter, loc, indicesTy, indices);
  else if (indices.getType() != indicesTy)
    indices = tensor::CastOp::create(rewriter, loc, indicesTy, indices);
  rewriter.replaceOp(op, {self, indices});
  return success();


The scalar sort fast path comes after the float-only dtype check, so rank-zero integer tensors are rejected before they reach the branch that simply returns the input. That path does not need a numeric comparison. Can we move the rank-zero branch before the element-type gate?
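The ordering being requested can be sketched as a small decision function. This is only an illustration of check ordering, not the C++ pattern itself: `plan_sort_lowering`, the string type names, and the returned labels are all hypothetical.

```python
FLOAT_TYPES = {"f32", "f16", "bf16"}

def plan_sort_lowering(rank, element_type, dim):
    """Return a label for the path taken, or None for a match failure."""
    if rank == 0:
        # A scalar has nothing to compare, so the element type is irrelevant.
        if dim not in (0, -1):
            return None
        return "scalar-fast-path"
    # Only the comparison-based path needs the floating-point restriction.
    if element_type not in FLOAT_TYPES:
        return None
    return "selection-sort"

scalar_int = plan_sort_lowering(rank=0, element_type="i32", dim=0)
vector_int = plan_sort_lowering(rank=1, element_type="i32", dim=0)
```

With the rank check first, a rank-zero `i32` input takes the fast path, while a rank-one `i32` input is still rejected by the element-type gate.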

Comment on lines +1666 to +1685
auto parsePrefixSlice = [&](AtenSliceTensorOp slice,
                            int64_t &sliceK) -> bool {
  int64_t sliceDim;
  if (!matchPattern(slice.getDim(), m_TorchConstantInt(&sliceDim)))
    return false;
  sliceDim = toPositiveDim(sliceDim, rank);
  if (sliceDim != dim)
    return false;
  int64_t start;
  if (!matchPattern(slice.getStart(), m_TorchConstantInt(&start)) ||
      start != 0)
    return false;
  int64_t step;
  if (!matchPattern(slice.getStep(), m_TorchConstantInt(&step)) ||
      step != 1)
    return false;
  if (!matchPattern(slice.getEnd(), m_TorchConstantInt(&sliceK)))
    return false;
  return sliceK > 0 && sliceK <= dimSize;
};


topk(k=0) is valid in PyTorch and should return empty values plus empty int64 indices, but this lowering misses that case. After decomposition, topk becomes sort followed by slices ending at 0, and parsePrefixSlice rejects those because it requires sliceK > 0. As a result, large inputs can fall back to full sort and hit the 128-element cap, while smaller inputs still run into the zero-sized-output rejection. Can we add a direct k == 0 fast path plus test coverage?
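One way to picture the requested change is a prefix matcher that accepts an end of 0, paired with a selection routine that returns early for k == 0. This is a hedged Python sketch: `parse_prefix_slice` and `topk_prefix` are hypothetical stand-ins that take plain integers rather than MLIR values, and `sorted` replaces the argmax-based selection.

```python
def parse_prefix_slice(slice_dim, start, step, end, sort_dim, dim_size, rank):
    """Return k if the slice is a prefix [0:k] along the sort dim, else None.

    k may be 0, so callers must test the result with `is not None`.
    """
    if slice_dim < 0:
        slice_dim += rank
    if slice_dim != sort_dim or start != 0 or step != 1:
        return None
    # Accept end == 0 so that a decomposed topk(k=0) is seen as an empty prefix.
    if 0 <= end <= dim_size:
        return end
    return None

def topk_prefix(values, k, largest=True):
    """Return (values, indices) of the first k elements in sorted order."""
    if k == 0:
        return [], []  # fast path: nothing to select
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)
    order = order[:k]
    return [values[i] for i in order], order

k_zero = parse_prefix_slice(-1, 0, 1, 0, sort_dim=1, dim_size=6, rank=2)
empty = topk_prefix([3.0, 1.0, 2.0], k_zero)
```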

Comment on lines +3203 to +3206
if (dimInt < 0)
  dimInt += rank;
if (dimInt < 0 || dimInt >= rank)
  return failure();


The second if is always true if the first is true, correct?
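To make the question concrete, the two checks from the quoted lines can be traced in Python for a few inputs. This is an illustrative sketch: `normalize_dim` is a hypothetical helper that returns `None` where the C++ returns `failure()`, and rank 3 is an arbitrary choice.

```python
def normalize_dim(dim, rank):
    """Return the non-negative dim, or None when it is out of range."""
    if dim < 0:
        dim += rank          # first check: wrap negative dims once
    if dim < 0 or dim >= rank:
        return None          # second check: still out of range after wrapping
    return dim

rank = 3
results = {d: normalize_dim(d, rank) for d in (-5, -3, -1, 0, 2, 3)}
```

For rank 3, a dim of -1 triggers the first check and passes the second, -5 triggers both, and 3 triggers only the second.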

Comment on lines +6689 to +6692
@export
@annotate_args([None, ([1, 6], torch.float32, True)])
def forward(self, x):
    return torch.ops.aten.topk(x, k=6, dim=-1, largest=False, sorted=True)


This NaN/Inf smallest-topk test uses k=6 on a length-6 input, so it mostly exercises full ordering rather than partial top-k selection. Can we change this test or add a second one with k < dim_size so the NaN-aware top-k behavior is actually covered?
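The difference between k equal to the dimension size and k smaller than it can be seen in a small reference model. This sketch assumes NaN orders after every other value, including positive infinity; `smallest_topk_nan_last` and the sample input are hypothetical and are not taken from the PR's test suite.

```python
import math

def smallest_topk_nan_last(values, k):
    """Return the k smallest values and their indices, treating NaN as largest."""
    def key(i):
        v = values[i]
        # Assumption: NaN sorts after all other values; ties keep index order.
        return (math.isnan(v), 0.0 if math.isnan(v) else v)
    order = sorted(range(len(values)), key=key)[:k]
    return [values[i] for i in order], order

x = [3.0, float("nan"), -math.inf, 1.0, math.inf, 0.0]
vals_k3, idx_k3 = smallest_topk_nan_last(x, 3)
vals_k6, idx_k6 = smallest_topk_nan_last(x, 6)
```

With k=6 every element is returned, so the test only checks ordering. With k=3 the result must also exclude NaN and positive infinity, which is the selection behaviour the comment asks to cover.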

Comment on lines 2349 to 2364
SmallVector<mlir::Complex<APFloat>> values;
for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
  for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
    double v = scale * i * j;
    double realV = cos(v);
    double imagV = -sin(v);

    bool unused;
    APFloat real(realV);
    real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                 &unused);
    APFloat imag(imagV);
    imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                 &unused);

    values.push_back(mlir::Complex<APFloat>(real, imag));


Why do we need these changes?

if (dimInt < 0)
dimInt += operandType.getSizes().size();
if (dimAttribute) {
int64_t rank = operandType.getSizes().size();


Could you please explain why the old logic was wrong and the new logic is correct?
