Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp #3272
Conversation
Force-pushed from b467625 to 23b54de
Looks like adding a folder causes a problem; repro: the input and output type could be a simple tensor type without a known dtype.
I think you want to prevent it from being folded when the element dtype is unknown. I haven't validated it, but something like this should work:

    OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
      auto inputType = cast<BaseTensorType>(getA().getType());
      auto outputType = cast<BaseTensorType>(getResult().getType());
      // Only fold when the input and result types match exactly.
      if (inputType != outputType)
        return nullptr;
      // Do not fold when either side lacks a known element dtype.
      if (!inputType.hasDtype() || !outputType.hasDtype())
        return nullptr;
      if (inputType.getDtype() != outputType.getDtype())
        return nullptr;
      return getA();
    }
Good news: your PR fixes 4 tests that previously failed. You may update the xfail set.
Feel free to land once CI is clean.
While playing with TorchDynamo on ResNet18, I noticed the following issues:
- prims.convert_element_type can't be canonicalized away even when the input and the output share the same type.
- aten.max_pool2d_with_indices is always used instead of aten.max_pool2d, even when the second returned output (indices) has no user.
This PR fixes the above issues by adding a folder to PrimsConvertElementTypeOp and a canonicalizer to AtenMaxPool2dWithIndicesOp.
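For reference, here is a minimal sketch of what such a canonicalization could look like in TorchOps.cpp. This is an illustration, not the exact patch in this PR: the generated accessor names (getResult0, getResult1, getSelf, getKernelSize, getStride, getPadding, getDilation, getCeilMode) are assumed from torch-mlir's ODS conventions.

    // Hypothetical sketch: when the indices result of
    // aten.max_pool2d_with_indices is unused, rewrite the op to
    // aten.max_pool2d.
    void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
        RewritePatternSet &patterns, MLIRContext *context) {
      patterns.add(+[](AtenMaxPool2dWithIndicesOp op,
                       PatternRewriter &rewriter) {
        // Bail out if anything still consumes the indices result.
        if (!op.getResult1().use_empty())
          return failure();
        // Build a plain max_pool2d with the same operands and the value
        // result type, then replace the value result of the original op.
        Value maxPool = rewriter.create<AtenMaxPool2dOp>(
            op.getLoc(), op.getResult0().getType(), op.getSelf(),
            op.getKernelSize(), op.getStride(), op.getPadding(),
            op.getDilation(), op.getCeilMode());
        rewriter.replaceAllUsesWith(op.getResult0(), maxPool);
        rewriter.eraseOp(op);
        return success();
      });
    }

Erasing the original op is only safe here because the pattern has already checked that the indices result has no users.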
Lit test:
cmake --build build --target check-torch-mlir-all