forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a canonicalization pattern for
aten.unflatten.int
(llvm#3656)
Addresses an issue in <llvm#3651> where some unflatten ops generated from onnx models weren't propagating static shape information. It may be necessary to add further optimizations for the more general case when some static information is present in the unflatten (or possibly reshape/view) op's `sizes` list, but not reflected in the output shape. These ops will only successfully infer shapes if the `sizes` list is gotten from a list of constant ints (with possibly one -1). A common example where this fails is when some of the `sizes` are determined from `aten.size.int` ops on dynamic tensors, and other `sizes` are known statically. This PR includes: - a canonicalizer for `aten.unflatten.int` which converts to `aten.unsqueeze` when it is expanding one dim to two, and one of the new dims is statically 1. - an improvement to the folder for `aten.__or__.bool` which does not rely on *both* operands being static.
- Loading branch information
Showing
3 changed files
with
92 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters