Optimize while scans when only last state is needed #216
@@ -677,7 +678,6 @@ def __init__(
typeConstructor: Optional[TensorConstructorType] = None,
truncate_gradient: int = -1,
name: Optional[str] = None,
as_while: bool = False,
This is not used anywhere. The information is contained in `info`.
Codecov Report
@@           Coverage Diff           @@
##             main     #216   +/-   ##
=======================================
  Coverage   80.43%   80.43%
=======================================
  Files         170      170
  Lines       45287    45322   +35
  Branches     9677     9687   +10
=======================================
+ Hits        36427    36457   +30
- Misses       6639     6641    +2
- Partials     2221     2224    +3
Need an approving review here. Some other work to speed up gradients that use Scans depends on this: #174
Thanks to your help on Slack I could read this much more closely now, so I have a few questions :)
pytensor/scan/rewriting.py
continue

u = node.inputs[0]
if not (u.owner and isinstance(u.owner.op, Subtensor)):
Is it possible that `u.owner` is `None`?
Not for clients. `u.owner` can only be `None` for input variables, but since these are clients of the scan they must not be input variables by definition.
Still, I like to always check anyway. That's what the `u.owner and ...` does. Were you suggesting I remove that first check?
OK, that's what I thought. I wondered why you checked, and couldn't think of a way that it could be `None` here. If it cannot be `None` here, perhaps an `assert` would be more appropriate, to better convey the intent? It is just a tiny detail, so I don't mind if you'd rather not change it :)
I was wrong, it can be `None`, because a Subtensor node can be a client of scan through an input other than the first (as an index, and not what is being indexed), even though that is unlikely.
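To make that last point concrete, here is a minimal sketch in plain Python. It does not use PyTensor: `Variable`, `Apply`, `subtensor` and `first_input_comes_from` are hypothetical toy stand-ins that only mimic the `owner` / `inputs` relationship discussed above. It shows a Subtensor node that is a client of a scan output through its index input, so that `node.inputs[0]` is a root input whose `owner` is `None`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(eq=False)
class Apply:
    """Toy graph node (hypothetical): an op name applied to input variables."""
    op: str
    inputs: List["Variable"]


@dataclass(eq=False)
class Variable:
    """Toy graph variable (hypothetical); `owner` is None for root inputs."""
    name: str
    owner: Optional[Apply] = None


def subtensor(indexed: Variable, index: Variable) -> Variable:
    # inputs[0] is what is being indexed, inputs[1] is the index itself.
    node = Apply("Subtensor", [indexed, index])
    return Variable(f"{indexed.name}[{index.name}]", owner=node)


def first_input_comes_from(node: Apply, op_name: str) -> bool:
    # Same shape as the guard in the diff: test `u.owner` before touching `.op`.
    u = node.inputs[0]
    return bool(u.owner and u.owner.op == op_name)


# A scan output (it has an owner) and an unrelated root input `x` (no owner).
scan_out = Variable("scan_out", owner=Apply("Scan", [Variable("n_steps")]))
x = Variable("x")

# x[scan_out]: this node is a client of the scan only through its index input.
client = subtensor(x, scan_out).owner
guarded = first_input_comes_from(client, "Subtensor")
```

Without the `u.owner and ...` part, evaluating `client.inputs[0].owner.op` would raise an `AttributeError` here, which is why an `assert` would not be a safe replacement in this toy setting.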
# Special case for recurrent outputs where only the last result
# is requested. This is needed for this rewrite to apply to
# do-while Scans at all. Otherwise, `get_canonical_form_slice` in
# the `else` branch would reintroduce a shape dependency on the
# original Scan which would lead this rewrite to abort in the end.
I have no idea what is happening here, but maybe this will be clearer when I understand the relationships between this rewrite and the new one…
Same issue as before: the old default branch would go through `get_canonical_form_slice`, which references the length of `x`, basically converting `x[-1]` to `x[len(x)-1]`. That causes a dependency on the outputs of the old while scan, leading the rewrite to abort in the end.
Here we add a special case for `x[-1]`, to not do that.
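As a rough analogy in plain Python (not PyTensor's implementation; `while_scan_all`, `while_scan_last` and `double_until_over_100` are hypothetical names made up for this sketch), the difference between the two index forms looks like this. Writing the index as `len(x) - 1` needs the full trace of the original loop just to know its length, whereas treating `x[-1]` as "the last state" lets the loop keep a single value.

```python
def while_scan_all(step, init, max_steps):
    """Run a do-while style loop and keep every state it produces."""
    states = [init]
    for _ in range(max_steps):
        new_state, stop = step(states[-1])
        states.append(new_state)
        if stop:
            break
    return states


def while_scan_last(step, init, max_steps):
    """Same loop, but only the most recent state is kept in memory."""
    state = init
    for _ in range(max_steps):
        state, stop = step(state)
        if stop:
            break
    return state


def double_until_over_100(x):
    # One loop step: returns the new state and whether the loop should stop.
    new_x = 2 * x
    return new_x, new_x > 100


# Canonical form: the index depends on the length of the full trace,
# so the trace-storing loop has to run before the index is even known.
trace = while_scan_all(double_until_over_100, 1, max_steps=50)
last_via_length = trace[len(trace) - 1]

# Special case: "last state" needs no length, so the small loop is enough.
last_direct = while_scan_last(double_until_over_100, 1, max_steps=50)
```

Both approaches return the same value; only the second avoids depending on anything the trace-storing loop computes, which is the property the special case for `x[-1]` is meant to preserve.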
Ok thank you, I think I understand everything now!
Thanks for the review @Armavica!
Closes #178
TODO: