Skip to content

Optimize while scans when only last state is needed #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 24, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 9, 2023

Closes #178

TODO:

  • Check with more taps

@@ -677,7 +678,6 @@ def __init__(
typeConstructor: Optional[TensorConstructorType] = None,
truncate_gradient: int = -1,
name: Optional[str] = None,
as_while: bool = False,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not used anywhere. The information is contained in info

@ricardoV94 ricardoV94 force-pushed the save_mem_scan_while_scan branch 2 times, most recently from d242d47 to d45e579 Compare February 10, 2023 17:52
@ricardoV94 ricardoV94 marked this pull request as ready for review February 10, 2023 17:52
@ricardoV94 ricardoV94 force-pushed the save_mem_scan_while_scan branch from d45e579 to f665170 Compare February 10, 2023 17:55
@ricardoV94 ricardoV94 changed the title Add special optimization for While Scan where only last state is used Optimize while scans when only last state is needed Feb 10, 2023
@ricardoV94 ricardoV94 force-pushed the save_mem_scan_while_scan branch from f665170 to 9e8663c Compare February 10, 2023 18:37
@codecov-commenter
Copy link

Codecov Report

Merging #216 (9e8663c) into main (5521d82) will increase coverage by 0.00%.
The diff coverage is 86.27%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #216   +/-   ##
=======================================
  Coverage   80.43%   80.43%           
=======================================
  Files         170      170           
  Lines       45287    45322   +35     
  Branches     9677     9687   +10     
=======================================
+ Hits        36427    36457   +30     
- Misses       6639     6641    +2     
- Partials     2221     2224    +3     
Impacted Files Coverage Δ
pytensor/scan/op.py 84.60% <ø> (ø)
pytensor/tensor/subtensor.py 89.64% <ø> (ø)
pytensor/scan/rewriting.py 79.39% <85.41%> (+0.14%) ⬆️
pytensor/tensor/rewriting/subtensor.py 88.73% <100.00%> (+0.04%) ⬆️

@ricardoV94 ricardoV94 marked this pull request as draft February 12, 2023 22:17
@ricardoV94 ricardoV94 force-pushed the save_mem_scan_while_scan branch 3 times, most recently from 1f46fa7 to a9f8026 Compare February 14, 2023 08:38
@ricardoV94 ricardoV94 marked this pull request as ready for review February 14, 2023 08:45
@ricardoV94 ricardoV94 force-pushed the save_mem_scan_while_scan branch from a9f8026 to a6cee04 Compare February 14, 2023 10:12
@ricardoV94 ricardoV94 force-pushed the save_mem_scan_while_scan branch from 8bec593 to 51e161d Compare February 15, 2023 09:55
@ricardoV94
Copy link
Member Author

Need an approving review here. Some other work to speedup gradients that use Scans depends on this: #174

Copy link
Member

@Armavica Armavica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks to your help on Slack I could read this much more closely now, so I have a few questions :)

continue

u = node.inputs[0]
if not (u.owner and isinstance(u.owner.op, Subtensor)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible that u.owner is None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for clients. u.owner can only be None for input variables, but since these are clients of the scan they must not be input variables by definition.

Still I like to always check anyway. That's what the u.owner and ... does. Were you suggesting I remove that first check?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that's what I thought, I wondered why you checked and couldn't think of a way that it could be None here. If it cannot be None here perhaps an assert would be more appropriate, to better convey the intent? It is just a tiny detail, so if you don't feel like changing, I don't mind if you don't :)

Copy link
Member Author

@ricardoV94 ricardoV94 Feb 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wrong, it can be None, because a Subtensor node can be a client of scan through an input other than the first (as an index, and not what is being indexed), even though that is unlikely.

Comment on lines +1423 to +1427
# Special case for recurrent outputs where only the last result
# is requested. This is needed for this rewrite to apply to
# do-while Scans at all. Otherwise, `get_canonical_form_slice` in
# the `else` branch would reintroduce a shape dependency on the
# original Scan which would lead this rewrite to abort in the end.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea what is happening here, but maybe this will be clearer when I understand the relationships between this rewrite and the new one…

Copy link
Member Author

@ricardoV94 ricardoV94 Feb 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as before, the old default branch would go through a get_canonical_form_slice slice, that will reference the length of x. Basically converting x[-1] to x[len(x)-1], which will cause a dependency on the outputs of the old while scan, leading the rewrite to abort in the end.

Here we add a special case for x[-1], to not do that.

@ricardoV94 ricardoV94 force-pushed the save_mem_scan_while_scan branch from 51e161d to 1d0d43c Compare February 22, 2023 13:26
@ricardoV94 ricardoV94 requested a review from Armavica February 22, 2023 13:26
Copy link
Member

@Armavica Armavica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thank you, I think I understand everything now!

@ricardoV94 ricardoV94 merged commit 63f8d6e into pymc-devs:main Feb 24, 2023
@ricardoV94
Copy link
Member Author

Thanks for the review @Armavica!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Apply scan memory save rewrite to while scans
4 participants