Create a separate internal helper function for XLA compilation #3852

Merged (1 commit) on Jul 24, 2020


shoyer
Collaborator

@shoyer shoyer commented Jul 24, 2020

XLA backends are written in C++, so method calls don't show up in Python
profiling results from cProfile. Adding an explicit function call fixes that.

This is helpful for interpreting profiling results, e.g., on the example from
#3847.

Before:

         70814996 function calls (69915267 primitive calls) in 112.804 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1193   24.936    0.021   30.336    0.025 xla.py:227(xla_primitive_callable)
  10524/1   16.342    0.002  112.991  112.991 xla.py:595(_xla_callable)
2014622/1843062    8.745    0.000   16.618    0.000 util.py:29(safe_map)
    18145    3.662    0.000    4.218    0.000 source_info_util.py:27(user_frame)
196061/183909    1.604    0.000   24.647    0.000 partial_eval.py:150(default_process_primitive)
   423499    1.569    0.000    1.569    0.000 {method 'reduce' of 'numpy.ufunc' objects}

After:

         71147652 function calls (70235594 primitive calls) in 101.718 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1294   38.894    0.030   38.894    0.030 xla.py:325(_backend_compile)
2017790/1844559    6.965    0.000   14.139    0.000 util.py:29(safe_map)
    18146    3.317    0.000    3.839    0.000 source_info_util.py:27(user_frame)
196226/184073    1.467    0.000   21.889    0.000 partial_eval.py:150(default_process_primitive)
   423771    1.419    0.000    1.419    0.000 {method 'reduce' of 'numpy.ufunc' objects}

We now clearly see that both `xla_primitive_callable` and `_xla_callable` are
slow for the same reason: ~40 seconds is spent inside XLA compilation.
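
As a rough illustration of the idea (a sketch, not the actual diff in this PR): work done in native code that cProfile does not see as a call is charged to the `tottime` of the Python function that triggered it. Routing it through a small named helper gives it a row of its own. In the sketch below only the name `_backend_compile` comes from the profile output above. Its signature here, `build_and_compile`, `profiled_function_names`, and the big-integer power standing in for the native XLA call are all assumptions made for illustration.

```python
import cProfile
import pstats


def _backend_compile(base, exponent):
    # Hypothetical stand-in for the call into native code. The `**` operator
    # on ints runs in C and is not reported by cProfile as a call, so without
    # this wrapper its time would land in the caller's tottime.
    return base ** exponent


def build_and_compile(exponent):
    # Some Python-side setup, then the native work routed through the helper.
    setup = sum(range(1000))
    return _backend_compile(3, exponent) + setup


def profiled_function_names(func, *args):
    # Run func under cProfile and collect the function names that got a row.
    profiler = cProfile.Profile()
    profiler.runcall(func, *args)
    stats = pstats.Stats(profiler)
    # stats.stats maps (filename, lineno, funcname) to the timing data.
    return {key[2] for key in stats.stats}


names = profiled_function_names(build_and_compile, 200_000)
print(sorted(names))
```

The helper shows up by name in the collected rows, separate from `build_and_compile`. For a listing ordered by internal time like the ones above, `pstats.Stats(profiler).sort_stats("tottime").print_stats(6)` should give a similar layout.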

@shoyer shoyer requested a review from hawkinsp July 24, 2020 17:11
@google-cla google-cla bot added the cla: yes label Jul 24, 2020
@hawkinsp
Collaborator

Note that this must be related to constant folding; it is not the time to compile the computation that is eventually produced, which is much shorter.

Perhaps try with the omnistaging branch?

@hawkinsp hawkinsp merged commit b7bcfa6 into jax-ml:master Jul 24, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 24, 2020
…l#3852)
