58 changes: 58 additions & 0 deletions e2e_testing/torchscript/nll_loss.py
@@ -60,3 +60,61 @@ def forward(self, x, y):
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
def NllLossModule_ignore_index(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))

class NllLossModule_backward(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([], torch.float32, True),
    ])
    def forward(self, grad_output, input, target, total_weight):
        return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
                                                self=input,
                                                target=target,
                                                weight=None,
                                                reduction=0,
                                                ignore_index=10,
                                                total_weight=total_weight)


@register_test_case(module_factory=lambda: NllLossModule_backward())
def NllLossModuleBackward_basic(module, tu: TestUtils):
    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
                   torch.tensor(3.))


class NllLossModule_backward_ignore_index(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([], torch.float32, True),
    ])
    def forward(self, grad_output, input, target, total_weight):
        return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
                                                self=input,
                                                target=target,
                                                weight=None,
                                                reduction=0,
                                                ignore_index=1,
                                                total_weight=total_weight)


@register_test_case(
    module_factory=lambda: NllLossModule_backward_ignore_index())
def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
                   torch.tensor(3.))
20 changes: 20 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1852,6 +1852,26 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `->` qualified(type($output)) `,` qualified(type($total_weight))";
}

def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
AnyTorchOptionalTensorType:$weight,
Torch_IntType:$reduction,
Torch_IntType:$ignore_index,
AnyTorchTensorType:$total_weight
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$grad_output `,` $self `,` $target `,` $weight `,` $reduction `,` $ignore_index `,` $total_weight attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `,` qualified(type($total_weight)) `->` qualified(type($result))";
}

def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics
9 changes: 9 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -70,6 +70,15 @@ struct ResultTypeState {
ScalarType result_type(const ResultTypeState &in_state);
ScalarType promote_skip_undefined(ScalarType a, ScalarType b);

//===----------------------------------------------------------------------===//
// These constants control the reduction behavior of the loss functions.
// `None`, `Mean`, and `Sum` correspond to "do not reduce", "mean of losses",
// and "sum of losses", respectively.
// Source:
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/Reduction.h
//===----------------------------------------------------------------------===//
enum Reduction { None, Mean, Sum, END };

} // namespace torch_upstream
} // namespace torch
} // namespace mlir
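For reference (not part of the diff), these integer values line up with the `reduction` argument of PyTorch's loss functions, where 0, 1, and 2 select "none", "mean", and "sum". A minimal eager-mode sketch, assuming only a stock PyTorch install:

import torch
import torch.nn.functional as F

logits = torch.log_softmax(torch.rand(3, 4), dim=1)
target = torch.tensor([2, 3, 0])

# reduction="none"/"mean"/"sum" correspond to the integer values 0/1/2
# (Reduction::None, Reduction::Mean, Reduction::Sum) used by the aten ops.
per_sample = F.nll_loss(logits, target, reduction="none")
assert torch.allclose(per_sample.mean(), F.nll_loss(logits, target, reduction="mean"))
assert torch.allclose(per_sample.sum(), F.nll_loss(logits, target, reduction="sum"))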
106 changes: 106 additions & 0 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -28,6 +29,7 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_upstream; // For ScalarType and type

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
@@ -1323,6 +1325,108 @@ class ConvertAtenNllLossForwardOp
};
} // namespace

// Given `grad_output`, `input`, and `target`, `nll_loss_backward` is given by:
//   for i in range(0, input.size(0)):
//     for j in range(0, input.size(1)):
//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
// TODO: The `weight` and `reduction` operands are still to be taken care of.
namespace {
class ConvertAtenNllLossBackwardOp
    : public OpConversionPattern<AtenNllLossBackwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value input = adaptor.self();
    Value target = adaptor.target();
    Value weight = adaptor.weight();
    Value gradOutput = adaptor.grad_output();

    int64_t reduction;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // TODO: Handle reduction.
    if (reduction != Reduction::None)
      return rewriter.notifyMatchFailure(
          op, "reduction along dimensions is not supported.");

    // TODO: Incorporate the weight argument.
    if (!weight.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented, the weight operand is not incorporated.");

    Value ignoreIndex = adaptor.ignore_index();
    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);

    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();

    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
    // required.
    if (inputRank != 2 || targetRank != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected input and target to be rank 2 and 1 respectively");
    }
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();

    Type elementType = resultType.getElementType();

    // Given that there is no reduction, the `grad_input` size is equal to the
    // `input` size.
    auto outputSize = getTensorSizes(rewriter, loc, input);
    Value initTensor0 =
        createZeroInitTensor(rewriter, loc, outputSize, elementType);
    Value zeroVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
                                       rewriter.getAffineDimExpr(1)};
    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
                                         getParallelIteratorTypeName()};
    auto indexingMaps =
        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultType, ValueRange{target, gradOutput}, initTensor0,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value indTarget = rewriter.create<arith::IndexCastOp>(
                      loc, rewriter.getIndexType(), args[0]);
                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);

                  // The final result is given by:
                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
                  Value cmpEq = rewriter.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::eq, indJ, indTarget);

                  // The target index shouldn't be equal to `ignoreIndex`.
                  Value cmpNe = rewriter.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
                  Value finalPredicate =
                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
                  Value negate =
                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
                  Value selectFinal = rewriter.create<mlir::SelectOp>(
                      loc, finalPredicate, negate, zeroVal);
                  b.create<linalg::YieldOp>(loc, selectFinal);
                })
            .getResult(0);

    rewriter.replaceOp(op, finalRes);
    return success();
  }
};
} // namespace
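As a sanity check on the element-wise formula documented above this pattern (illustrative only, not part of the diff), the unreduced case with `ignore_index` masking can be sketched in eager PyTorch; `nll_loss_backward_ref` is a hypothetical helper name used just for this example:

import torch

def nll_loss_backward_ref(grad_output, input, target, ignore_index):
    # grad_input[i][j] = -grad_output[i] when j == target[i] and
    # target[i] != ignore_index; 0 otherwise (reduction=0, no weight).
    grad_input = torch.zeros_like(input)
    for i in range(input.size(0)):
        if target[i].item() != ignore_index:
            grad_input[i, target[i]] = -grad_output[i]
    return grad_input

grad_output, input = torch.rand(3), torch.rand(3, 4)
target = torch.tensor([2, 3, 0])
expected = torch.ops.aten.nll_loss_backward(
    grad_output, input, target, None, 0, 1, torch.tensor(3.))
assert torch.allclose(nll_loss_backward_ref(grad_output, input, target, 1), expected)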

namespace {
// See comments in convertMmOp and the heading for this section for general
// considerations. This function needs to be auto-generated.
@@ -4525,6 +4629,8 @@ class ConvertTorchToLinalg
    patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
    target.addIllegalOp<AtenNllLossForwardOp>();
    patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
    target.addIllegalOp<AtenNllLossBackwardOp>();
    patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
    target.addIllegalOp<AtenIndexSelectOp>();
    patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
    patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
25 changes: 23 additions & 2 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -489,6 +489,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      return visitBinaryScalarOp(op, operands);
    } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
      return visitAtenNllLossForwardOp(nllForwardOp, operands);
    } else if (auto nllBackwardOp = dyn_cast<AtenNllLossBackwardOp>(op)) {
      return visitAtenNllLossBackwardOp(nllBackwardOp, operands);
    } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
      return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
    } else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
@@ -647,6 +649,9 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
  ChangeResult visitAtenNllLossForwardOp(
      AtenNllLossForwardOp op,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitAtenNllLossBackwardOp(
      AtenNllLossBackwardOp op,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitAtenNativeLayerNormOp(
      AtenNativeLayerNormOp op,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1188,8 +1193,8 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(

  if (self.hasSizes &&
      matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
-   // reduction == 1 means reduce 1st dim.
-   resultRank = reduction == 1 ? resultRank - 1 : resultRank;
+   if (reduction != Reduction::None)
+     resultRank -= 1;
  }
  outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
  outputKnowledge.hasSizes = true;
@@ -1199,6 +1204,22 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
  return resultLattice;
}

ChangeResult TypeAnalyzer::visitAtenNllLossBackwardOp(
    AtenNllLossBackwardOp op,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto self = operands[1]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());

  knowledge.dtype = self.dtype;
  if (self.hasSizes) {
    unsigned resultRank = self.sizes.size();
    knowledge.sizes.resize(resultRank, kUnknownSize);
    knowledge.hasSizes = true;
  }
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
    AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto operand = operands[0]->getValue();
@@ -548,6 +548,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")

# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")