Create a separate internal helper function for XLA compilation #3852
Merged
XLA backends are written in C++, so method calls don't show up in Python profiling results from cProfile. Adding an explicit function call fixes that.

This is helpful for interpreting profiling results, e.g., on the example from jax-ml#3847.

Before:

    70814996 function calls (69915267 primitive calls) in 112.804 seconds

    Ordered by: internal time

             ncalls  tottime  percall  cumtime  percall filename:lineno(function)
               1193   24.936    0.021   30.336    0.025 xla.py:227(xla_primitive_callable)
            10524/1   16.342    0.002  112.991  112.991 xla.py:595(_xla_callable)
    2014622/1843062    8.745    0.000   16.618    0.000 util.py:29(safe_map)
              18145    3.662    0.000    4.218    0.000 source_info_util.py:27(user_frame)
      196061/183909    1.604    0.000   24.647    0.000 partial_eval.py:150(default_process_primitive)
             423499    1.569    0.000    1.569    0.000 {method 'reduce' of 'numpy.ufunc' objects}

After:

    71147652 function calls (70235594 primitive calls) in 101.718 seconds

    Ordered by: internal time

             ncalls  tottime  percall  cumtime  percall filename:lineno(function)
               1294   38.894    0.030   38.894    0.030 xla.py:325(_backend_compile)
    2017790/1844559    6.965    0.000   14.139    0.000 util.py:29(safe_map)
              18146    3.317    0.000    3.839    0.000 source_info_util.py:27(user_frame)
      196226/184073    1.467    0.000   21.889    0.000 partial_eval.py:150(default_process_primitive)
             423771    1.419    0.000    1.419    0.000 {method 'reduce' of 'numpy.ufunc' objects}

We now clearly see that both `xla_primitive_callable` and `_xla_callable` are slow for the same reason, and that ~40 seconds are spent inside XLA compilation.
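The pattern described above can be sketched in a few lines. This is only an illustration under stated assumptions, not the code from this PR: `FakeBackend`, `primitive_callable`, `jit_callable`, and `profile_report` are hypothetical names, and the backend here is plain Python, whereas the real one is a C++ extension object whose method calls may not get a row of their own. The sketch reuses the helper name `_backend_compile` from the profile output and shows that routing every compile call through one small Python function gives cProfile a single row to aggregate them under.

```python
import cProfile
import io
import pstats
import time


class FakeBackend:
    """Hypothetical stand-in for an extension-backed compiler object."""

    def compile(self, computation):
        time.sleep(0.01)  # pretend this is expensive native work
        return ("executable", computation)


def _backend_compile(backend, computation):
    # A dedicated Python frame: cProfile records this function as its own
    # row, so compile calls from every caller are aggregated here.
    return backend.compile(computation)


def primitive_callable(backend, computation):
    # Hypothetical caller 1: compiles through the shared helper.
    return _backend_compile(backend, computation)


def jit_callable(backend, computation):
    # Hypothetical caller 2: compiles through the same helper.
    return _backend_compile(backend, computation)


def profile_report():
    """Profile both callers and return the pstats report as a string."""
    backend = FakeBackend()
    profiler = cProfile.Profile()
    profiler.enable()
    primitive_callable(backend, "add")
    jit_callable(backend, "matmul")
    profiler.disable()
    out = io.StringIO()
    pstats.Stats(profiler, stream=out).sort_stats("cumulative").print_stats()
    return out.getvalue()


if __name__ == "__main__":
    print(profile_report())
```

In the printed report, `_backend_compile` appears once with two calls, which makes it easier to see that both callers spend their time in the same place.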
Note that this must be related to constant folding; it is not the time to compile the computation that is eventually produced, which is much shorter. Perhaps try with the omnistaging branch?
hawkinsp approved these changes on Jul 24, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request on Jul 24, 2020