Skip to content

[TORCH] Add Kullback-Leibler divergence loss support #4204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9452,6 +9452,32 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
}];
}

def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
Torch_IntType:$reduction,
Torch_BoolType:$log_target
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenKlDivOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenKlDivOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
33 changes: 33 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10688,6 +10688,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.kl_div\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.prim.Uninitialized : !torch.list<int>\n"
" %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.list<int>\n"
" }\n"
" torch.prim.If.yield %5 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
Expand Down Expand Up @@ -14517,6 +14542,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %int3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.kl_div\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
Expand Down
79 changes: 79 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10471,6 +10471,84 @@ class DecomposeAtenNllLossForwardOp
};
} // namespace

namespace {
class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenKlDivOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value target = op.getTarget();
Value reductionValue = op.getReduction();
Value logTargetValue = op.getLogTarget();

auto selfTy = cast<ValueTensorType>(self.getType());
auto targetTy = cast<ValueTensorType>(target.getType());
auto outTy = cast<ValueTensorType>(op.getType());

if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "require self, target and output having sizes!");
}

if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "require self, target and output having dtype!");
}

// Extract boolean value from logTarget argument
bool logTargetBool;
if (!matchPattern(logTargetValue, m_TorchConstantBool(&logTargetBool)))
return rewriter.notifyMatchFailure(
op, "Expected a constant boolean value for logTargetBool");

// Default: target tensor is not in log space
Value logOfTarget;
if (!logTargetBool) {
logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
} else {
logOfTarget = target;
}

Value constOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
self, constOne);

// if target tensor is already in log space
if (logTargetBool) {
target = rewriter.create<AtenExpOp>(loc, targetTy, target);
}
Value lossPointwise =
rewriter.create<AtenMulTensorOp>(loc, targetTy, target, subValue);

// Extract reduction int value from reduction argument
int64_t reduction;
if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
return rewriter.notifyMatchFailure(op,
"reduction should be a constant int!");
}

Value loss;
Value none = rewriter.create<ConstantNoneOp>(loc);
// reduction: mean
if (reduction == 1) {
loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
} else if (reduction == 2) {
// reduction: sum
loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
} else {
// reduction: none
loss = lossPointwise;
}

rewriter.replaceOp(op, loss);

return success();
}
};
} // namespace

namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
Expand Down Expand Up @@ -12386,6 +12464,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenFlipudOp>();
target.addIllegalOp<AtenLogaddexpOp>();
target.addIllegalOp<AtenLogaddexp2Op>();
target.addIllegalOp<AtenKlDivOp>();

for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
Expand Down
15 changes: 15 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
# RuntimeError: attribute lookup is not defined on builtin:
"KlDivLossModule_batchmean_reduction_basic",
}

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
Expand Down Expand Up @@ -386,6 +388,12 @@
"MaxPool3dStaticModule_basic",
# Looks like incorrect fx graph conversion
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
# error: failed to legalize operation 'torch.aten.xlogy.Tensor'
"KlDivLossModule_default_basic",
"KlDivLossModule_reduction_is_none_basic",
"KlDivLossModule_mean_reduction_basic",
"KlDivLossModule_sum_reduction_basic",
"KlDivLossModule_batchmean_reduction_basic",
}

FX_IMPORTER_XFAIL_SET = {
Expand Down Expand Up @@ -3068,6 +3076,7 @@
"NllLossStaticModule_mean_basic",
"NllLossModule_sum_basic",
"NllLossStaticModule_sum_basic",
"KlDivLossModule_batchmean_reduction_basic",
"NormScalarComplexModule_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",
Expand Down Expand Up @@ -3953,6 +3962,12 @@
"NllLossStaticModule_mean_basic",
"NllLossStaticModule_sum_basic",
"NllLossStaticModule_weight_basic",
"KlDivLossModule_default_basic",
"KlDivLossModule_reduction_is_none_basic",
"KlDivLossModule_reduction_is_none_log_target_is_true_basic",
"KlDivLossModule_mean_reduction_basic",
"KlDivLossModule_sum_reduction_basic",
"KlDivLossModule_batchmean_reduction_basic",
"Exp2StaticModule_basic",
"ElementwiseRreluWithNoiseEvalModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,14 @@ def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
def aten〇deg2rad〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1, log_target: bool = False) -> List[int]:
if reduction == 0:
return upstream_shape_functions.unary(self)
elif reduction in [1, 2]:
return []
else:
assert False, "Invalid reduction value."

@check_shape_function([
Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
Expand Down Expand Up @@ -4523,6 +4531,14 @@ def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tu
assert mat2_dtype == torch.int8
return torch.int32

def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, log_target: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
target_rank, target_dtype = target_rank_dtype
ranks: List[Optional[int]] = [self_rank, target_rank]
dtypes = [self_dtype, target_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
return promoted_dtype

@check_dtype_function(_check_two_tensor_op(
output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
)
emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ def register_all_tests():
from . import gridsampler
from . import meshgrid
from . import timeout
from . import kl_div_loss
Loading
Loading