Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions lib/Conversion/TorchToSCF/TorchToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
targetType = Torch::IntType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to});
} else if (auto tty = dyn_cast<RankedTensorType>(targetType)) {
targetType = op.getIterArgsInit()[barg.index()].getType();
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to});
}
if (!torchArg)
return rewriter.notifyMatchFailure(op,
Expand All @@ -173,14 +177,6 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
"unsupported type of the operand");
loopConditionIterArgs.push_back(shouldContinue);
for (auto torchArg : primLoopConditionOp.getIterArgs()) {
Type torchType = torchArg.getType();

// If the argument is a torch tensor, directly add it in the list of
// iter args.
if (isa<Torch::BaseTensorType>(torchType)) {
loopConditionIterArgs.push_back(torchArg);
continue;
}
Value arg = typeConverter->materializeTargetConversion(
rewriter, scfWhileOp->getLoc(),
typeConverter->convertType(torchArg.getType()), {torchArg});
Expand Down
33 changes: 33 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,36 @@ def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
x_test = torch.zeros([7, 9]).float()

module.forward(x_test)


# ==============================================================================


class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([7, 9], torch.float32, True),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
from torch._higher_order_ops.while_loop import while_loop

def body_fn(i, x):
return i + 1, x + 1

i0 = torch.tensor(0)

out_i, out_x = while_loop(lambda i, x: i < 3, body_fn, (i0, x))
return out_i, out_x


@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeHOPModule())
def TorchPrimLoopWhileLikeHOPModule_basic(module, tu: TestUtils):
x_test = torch.zeros([7, 9]).float()

module.forward(x_test)
Loading
Loading