Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp #3272

Merged
merged 5 commits into from
May 2, 2024

Conversation

zezhang
Copy link
Collaborator

@zezhang zezhang commented May 1, 2024

While playing with TorchDynamo on ResNet18. I notice following issues:

  • prims.convert_element_type can’t be canonicalized even if the input and the output share the same type

  • aten.max_pool2d_with_indices is always used instead of aten.max_pool2d, even if the second returned output (indices) has no user

This PR fixes above issues by adding a folder to the PrimsConvertElementTypeOp and a canonicalizer to the AtenMaxPool2dWithIndicesOp

Lit test:

cmake --build build --target check-torch-mlir-all

@zezhang zezhang force-pushed the zezhang/dynamo_resnet18 branch from b467625 to 23b54de Compare May 1, 2024 22:47
@zezhang zezhang requested a review from sjain-stanford May 1, 2024 22:47
@zezhang
Copy link
Collaborator Author

zezhang commented May 2, 2024

Looks like adding a folder for PrimsConvertElementTypeOp makes the linalg PrimsConvertElementTypeModule_basic test fail.

repro: python -m e2e_testing.main --config=linalg -v -f 'PrimsConvertElementTypeModule_basic'

Because the input and output type could be a simple torch.tensor (no element dtype is available). We will have to use the dtype argument of the op to specify the output type. Any idea on this? @sjain-stanford

@sjain-stanford
Copy link
Member

Because the input and output type could be a simple torch.tensor (no element dtype is available). We will have to use the dtype argument of the op to specify the output type.

I think you want to prevent it from being folded when element dtype is unknown. I haven't validated it but something like this should work:

OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
  auto inputType = cast<BaseTensorType>(getA().getType());
  auto outputType = cast<BaseTensorType>(getResult().getType());
  if (inputType != outputType)
    return nullptr;
  if (!inputType.hasDtype() || !outputType.hasDtype())
    return nullptr;
  if (inputType.getDtype() != outputType.getDtype())
    return nullptr;
  return getA();
}

@sjain-stanford
Copy link
Member

sjain-stanford commented May 2, 2024

Good news is your PR fixes 4 tests which previously failed. You may update the xfail set.

  ****** Unexpectedly Passed tests - 4 tests
      XPASS - "MaxPool2dEmptyStrideStaticModule_basic"
      XPASS - "MaxPool2dStaticCeilModeTrueModule_basic"
      XPASS - "MaxPool2dStaticModule_basic"
      XPASS - "ResNet18StaticModule_basic"

Copy link
Member

@sjain-stanford sjain-stanford left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to land once CI is clean.

@zezhang zezhang merged commit 11cd7cd into llvm:main May 2, 2024
3 checks passed
@zezhang zezhang deleted the zezhang/dynamo_resnet18 branch May 2, 2024 07:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants