-
Notifications
You must be signed in to change notification settings - Fork 20
Add check for second value in sum: Logsumexp #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
y = torch.sum(torch.log(torch.exp(x)), dim=1) | ||
y = torch.sum(torch.log(torch.exp(x)), dim=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change to the order of calls to log(sum(exp()))
as we discussed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought these were false test cases. Will update it
torchfix/visitors/misc/__init__.py
Outdated
if ( | ||
self.get_specific_arg( | ||
node.args[0].value, arg_name="dim", arg_pos=1 | ||
).value.value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You only check for value of the argument when it's present.
It should be two nested if's - first if it's present, then if value is not None.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first "if" condition on line 187 checks for the presence of "dim" and then if confirmed it is moved to the second "if" in line 190.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see now.
You should assign the arg to a variable to reduce code duplication.
And then add assert
that it's not None
, otherwise MyPy is complaining:
https://github.com/pytorch-labs/torchfix/actions/runs/13081681874/job/36506448560?pr=90
See this for example https://github.com/pytorch-labs/torchfix/blob/main/torchfix/visitors/deprecated_symbols/__init__.py#L35
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You may also need to use ensure_type
, like here: https://github.com/pytorch-labs/torchfix/blob/main/torchfix/visitors/deprecated_symbols/qr.py#L19
Please run mypy locally and verify it passes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran the code and mypy errors are only resolved when I do a isinstance()
check.
And in doing so I need to check for both Integer
and Tuple
since the dim
can have both of the types as its argument value and not of type Name
which is there when None
value is there.
And since value of tuple cannot be retrieved through .value
. I need to update the code to different code structure that could handle both integer and tuple values for dim
which are Integers and are not None
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated the code and have made the necessary changes to handle any future type based issues.
Updated the test cases as well
torchfix/visitors/misc/__init__.py
Outdated
node.args[0].value, arg_name="dim", arg_pos=1 | ||
) | ||
if dim_arg: # checks if dim argument is present | ||
if isinstance( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These lines are redundant, no?
Later there are checks for cst.Integer
and cst.Tuple
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Removed it since they were test code.
torchfix/visitors/misc/__init__.py
Outdated
dim_arg.value, cst.Tuple | ||
): # checks if dim argument is an integer or tuple | ||
if ( | ||
isinstance(dim_arg.value, cst.Integer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cst.Integer
can not be None, meaningless condition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here the condition checks if the value is of type integer and also makes sure that the value it holds is also not None since Tuples in dim
cannot have None values
y = torch.log(torch.sum(torch.exp(x)), dim=None) #dim is not part of the sum fuction call and dim is None | ||
y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None | ||
y = torch.log(torch.sum(torch.exp(x), dim=(1,None))) #dim argument cannot be a tuple with None | ||
y = torch.log(torch.sum(torch.exp(x), dim=(None,None))) #dim argument cannot be a tuple with None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to check for dim=(None,None)
or dim=(1,None)
, it can not happen because if present dim
is an int or tuple of ints: https://pytorch.org/docs/stable/generated/torch.sum.html
closing this in favor of #91 |
Added conditions to check if second value exists in "sum" for Logsumexp update. If not then no change.
Files updated: