-
Notifications
You must be signed in to change notification settings - Fork 632
Improve list index normalization SimplifyShapeCalculations #710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -21,6 +21,10 @@ namespace Torch { | |||
int64_t toPositiveDim(int64_t dim, int64_t inputRank); | |||
bool isValidDim(int64_t dim, int64_t inputRank); | |||
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems); | |||
/// Returns the dimension indicated by `v` for a list of given `rank`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency (and I also think it reads a little better), I think it would be nice to make this a matcher, so that instead of
if (!matchPattern(setItem.idx(), m_TorchConstantInt(&index)))
return failure();
We instead do
if (!matchPattern(setItem.idx(), m_LegalConstantIndexIntoListOfSize(&index, runningList.size())))
return failure();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is a separate discussion if we want to migrate some of these
if (!matchPattern(value, m_Something(&thing)))
return failure();
to
auto valueOpt = matchSomething(value);
if (!valueOpt)
return failure();
or
int64_t thing;
if (!matchSomething(&thing))
return failure()
// (kind of like getListConstructElements)
as a matter of style. Right now I want to avoid having too many ways of matching in the codebase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, on thinking about it more, I'll need to do a pass through the codebase to make this consistant anyways because of getListConstructElements, so for now, I'm fine moving forward with this patch, calling the function matchLegalConstantIndexIntoListOfSize. Sorry for all the back and forth.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've renamed the function as you suggested.
I understand the discussion on consistency. I'm happy to contribute to this. For matchLegalConstantIndexIntoListOfSize
specifically there were more places in the codebase with the same tests, but with different messages produced for the different reasons it could fail. I think it would be useful to think if we'd want to keep the ability to generate different failure messages. I'd appreciate if this can indeed be done in a new PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might want to update the comment too -- it seems to be written from the perspective of a shape, but the utility is really about a general list.
a5506b2
to
bc23d9f
Compare
@@ -21,6 +21,10 @@ namespace Torch { | |||
int64_t toPositiveDim(int64_t dim, int64_t inputRank); | |||
bool isValidDim(int64_t dim, int64_t inputRank); | |||
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems); | |||
/// Returns the dimension indicated by `v` for a list of given `rank`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might want to update the comment too -- it seems to be written from the perspective of a shape, but the utility is really about a general list.
The PR is failing for an unrelated reason, presumably an update to
|
You can go ahead and push. It is unrelated to your PR. |
The reified code to compute the shape of torch.aten.constant_pad_nd uses negative indices when setting list elements. This was not converted to a positive offset in one place in SimplifyShapeCalculations which prevented computation of the static shape.
bc23d9f
to
f2719a1
Compare
* edits to the dockerfile, modification of PR llvm#707 Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> * edit of comments Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> * edit of comments Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> * gong's suggested changes Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com> * update Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
The reified code to compute the shape of
torch.aten.constant_pad_nd
uses negative indices when setting list elements. This was not converted to a positive offset in one place inSimplifyShapeCalculations
which prevented computation of the static shape.Additionally I added a utility function
getMatchedListDim()
to replace a number of equivalent pieces of code, as the same code was needed for the fix.