[TorchToTosa] Add lowering for AtenSortOp #4581
5fb0d05 to 3a5ccc4
Lallapallooza left a comment
Thanks for the patch; a few suggestions.
bca17c9 to 5c4b7ba
5c4b7ba to 9fad881
Lallapallooza left a comment
Thanks for the update; a few questions.
@@ -281,15 +281,46 @@ static FailureOr<Value> createIntOrFloatCompareOp(PatternRewriter &rewriter,
  }

  if (isa<mlir::FloatType>(elementType)) {
Can we split the TMTensor comparator change out of this TOSA sort PR?
  Type elementTy = selfTy.getElementType();
  if (!elementTy.isF32() && !elementTy.isF16() && !elementTy.isBF16())
    return rewriter.notifyMatchFailure(
        op, "only f32, f16, and bf16 element types are supported");

  bool descending;
  if (!matchPattern(op.getDescending(), m_TorchConstantBool(&descending)))
    return rewriter.notifyMatchFailure(
        op, "unimplemented: only constant descending value is supported");

  int64_t dim;
  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
    return rewriter.notifyMatchFailure(
        op, "unimplemented: only constant dim value is supported");

  int64_t rank = selfTy.getRank();
  if (rank == 0) {
    if (dim != 0 && dim != -1)
      return rewriter.notifyMatchFailure(op, "scalar sort dim is invalid");

    auto indicesTy = cast<RankedTensorType>(
        getTypeConverter()->convertType(op.getResult(1).getType()));
    Value indices = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
    if (indicesTy.getElementType().isInteger(64))
      indices = tosa::CastOp::create(rewriter, loc, indicesTy, indices);
    else if (indices.getType() != indicesTy)
      indices = tensor::CastOp::create(rewriter, loc, indicesTy, indices);
    rewriter.replaceOp(op, {self, indices});
    return success();
The scalar sort fast path comes after the float-only dtype check, so rank-zero integer tensors are rejected before they reach the branch that simply returns. This path does not need a numeric comparison. Can we move the rank-zero branch before the element-type gate?
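To make the requested ordering concrete, here is a minimal pure-Python model of the check sequence. It is only a sketch: `lower_sort` and the dtype strings are hypothetical stand-ins for the C++ pattern, and the returned strings merely label which branch was taken.

```python
# Hypothetical stand-in for the float-only gate in the C++ pattern.
FLOAT_DTYPES = {"f32", "f16", "bf16"}


def lower_sort(rank, dtype, dim):
    """Return a label describing which branch the modeled pattern takes."""
    # Rank-zero fast path first: sorting a scalar needs no comparison,
    # so the element type does not matter on this branch.
    if rank == 0:
        if dim not in (0, -1):
            return "fail: scalar sort dim is invalid"
        return "ok: return self and a zero index"
    # Only non-scalar inputs reach the float-only gate.
    if dtype not in FLOAT_DTYPES:
        return "fail: only f32, f16, and bf16 element types are supported"
    return "ok: general sort lowering"


print(lower_sort(0, "i32", 0))  # scalar integer input takes the fast path
print(lower_sort(1, "i32", 0))  # ranked integer input is still rejected
```

With the gate placed after the rank-zero branch, a rank-zero integer input is accepted while ranked integer inputs keep failing as before.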
  auto parsePrefixSlice = [&](AtenSliceTensorOp slice,
                              int64_t &sliceK) -> bool {
    int64_t sliceDim;
    if (!matchPattern(slice.getDim(), m_TorchConstantInt(&sliceDim)))
      return false;
    sliceDim = toPositiveDim(sliceDim, rank);
    if (sliceDim != dim)
      return false;
    int64_t start;
    if (!matchPattern(slice.getStart(), m_TorchConstantInt(&start)) ||
        start != 0)
      return false;
    int64_t step;
    if (!matchPattern(slice.getStep(), m_TorchConstantInt(&step)) ||
        step != 1)
      return false;
    if (!matchPattern(slice.getEnd(), m_TorchConstantInt(&sliceK)))
      return false;
    return sliceK > 0 && sliceK <= dimSize;
  };
topk(k=0) is valid in PyTorch and should return empty values plus empty int64 indices, but this lowering misses that case. After decomposition, topk becomes sort followed by slices ending at 0, and parsePrefixSlice rejects that because it requires sliceK > 0. As a result, large inputs can fall back to full sort and hit the 128-element cap, while smaller inputs still run into the zero-sized-output rejection. Can we add a direct k == 0 fast path plus test coverage?
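The effect of the `sliceK > 0` bound can be modeled in a few lines of pure Python. This is a sketch under assumptions: `parse_prefix_slice`, its `allow_empty` flag, and `topk_via_sort` are hypothetical names for illustration, not part of the patch.

```python
def parse_prefix_slice(start, step, end, dim_size, allow_empty):
    """Return k if the slice is an accepted prefix [0:end:1], else None.

    allow_empty=False models the quoted `sliceK > 0` bound;
    allow_empty=True models a hypothetical relaxed bound that admits k == 0.
    Callers must test `is None`, because an accepted k of 0 is falsy.
    """
    if start != 0 or step != 1:
        return None
    lower_ok = end >= 0 if allow_empty else end > 0
    return end if lower_ok and end <= dim_size else None


def topk_via_sort(values, k):
    """Model of decomposed topk: sort descending, then keep a length-k prefix."""
    order = sorted(range(len(values)), key=lambda i: -values[i])
    return [values[i] for i in order[:k]], order[:k]


# The quoted bound rejects the k == 0 prefix; the relaxed bound accepts it.
print(parse_prefix_slice(0, 1, 0, 8, allow_empty=False))
print(parse_prefix_slice(0, 1, 0, 8, allow_empty=True))
# The expected k == 0 result is a pair of empty lists.
print(topk_via_sort([1.0, 3.0, 2.0], 0))
```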
  if (dimInt < 0)
    dimInt += rank;
  if (dimInt < 0 || dimInt >= rank)
    return failure();
The second `if` is always true if the first is true, correct?
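For reference when discussing this, the two quoted checks can be traced with a small pure-Python mirror. `normalize_dim` is a hypothetical name, and returning `None` stands in for `return failure()`.

```python
def normalize_dim(dim, rank):
    """Mirror the two quoted checks: wrap a negative dim, then range-check."""
    if dim < 0:
        dim += rank
    if dim < 0 or dim >= rank:
        return None  # corresponds to `return failure()`
    return dim


# Trace a few inputs for rank 3.
for d in (-4, -3, -1, 0, 2, 3):
    print(d, normalize_dim(d, 3))
```

In this model a dim in `[-rank, -1]` enters the first branch and then passes the range check, while a dim below `-rank` enters the first branch and fails it.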
    @export
    @annotate_args([None, ([1, 6], torch.float32, True)])
    def forward(self, x):
        return torch.ops.aten.topk(x, k=6, dim=-1, largest=False, sorted=True)
This NaN/Inf smallest-topk test uses k=6 on a length-6 input, so it mostly exercises full ordering rather than partial top-k selection. Can we change this test or add a second one with k < dim_size so the NaN-aware top-k behavior is actually covered?
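A pure-Python model shows why `k < dim_size` is the more telling case. The sketch assumes NaN orders after every other value (including +inf) for a smallest-first selection; `topk_smallest` and the sample input are hypothetical and do not come from the test suite.

```python
import math


def topk_smallest(values, k):
    """Model of topk(largest=False, sorted=True) on a 1-D list.

    Assumption: NaN orders after every other value, including +inf.
    """
    def key(i):
        v = values[i]
        # NaN entries sort last; non-NaN entries sort by value.
        return (math.isnan(v), 0.0 if math.isnan(v) else v)

    order = sorted(range(len(values)), key=key)[:k]
    return [values[i] for i in order], order


x = [2.0, math.nan, -math.inf, math.inf, 0.5, -1.0]

# k == len(x): every element, NaN included, appears in the result.
print(topk_smallest(x, 6)[1])
# k < len(x): the result is only right if NaN and +inf are left out.
print(topk_smallest(x, 3))
```

With `k=6` the NaN is always part of the output, so only its position is checked; with `k=3` a wrong NaN ordering would change which elements are selected.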
  SmallVector<mlir::Complex<APFloat>> values;
  for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
    for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
      double v = scale * i * j;
      double realV = cos(v);
      double imagV = -sin(v);

      bool unused;
      APFloat real(realV);
      real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                   &unused);
      APFloat imag(imagV);
      imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                   &unused);

      values.push_back(mlir::Complex<APFloat>(real, imag));
Why do we need these changes?
  if (dimInt < 0)
    dimInt += operandType.getSizes().size();
  if (dimAttribute) {
    int64_t rank = operandType.getSizes().size();
Could you please explain why the old logic was wrong and the new logic is correct?
Summary
Adds TorchToTosa lowering support for torch.aten.sort.
The lowering supports statically-shaped ranked tensors with floating-point element types, constant dim, and constant descending. It emits repeated tosa.argmax/tosa.gather selections, masks selected elements with a sentinel, and returns both sorted values and indices.
This also handles the common decomposed topk pattern where sort results are immediately prefix-sliced, lowering only the requested prefix instead of sorting the full dimension.
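The selection scheme described above can be sketched on a 1-D list in pure Python. This is an illustrative model, not the emitted TOSA: `argmax` and `selection_sort_prefix` are hypothetical helpers, negating the input for ascending order and using negative infinity as the sentinel are assumptions of this sketch, and inputs that already contain NaN or infinities are not handled here.

```python
import math


def argmax(xs):
    """Index of the first maximum, standing in for tosa.argmax."""
    best = 0
    for i, x in enumerate(xs):
        if x > xs[best]:
            best = i
    return best


def selection_sort_prefix(values, k, descending=True):
    """Select k elements in sorted order by repeated argmax plus masking."""
    # Assumption of this sketch: ascending order is modeled by negation.
    work = [v if descending else -v for v in values]
    out_vals, out_idx = [], []
    for _ in range(k):
        i = argmax(work)            # pick the current extreme
        out_idx.append(i)
        out_vals.append(values[i])  # gather the original value
        work[i] = -math.inf         # mask it with the sentinel
    return out_vals, out_idx


# Full sort uses k == len(values); a topk-style prefix uses a smaller k.
print(selection_sort_prefix([3.0, 1.0, 2.0], 3))
print(selection_sort_prefix([3.0, 1.0, 2.0], 2, descending=False))
```

Each iteration costs one pass over the dimension, which is why lowering only the requested prefix is cheaper than sorting the full dimension when k is small.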
Details
aten.sort for ascending and descending order.
transposing back.
i64 indices when required by the Torch result type.
0.
AtenSortOp folding so invalid static dims do not fold incorrectly.
invalid-dim no-fold behavior.