Skip to content

Commit

Permalink
[JAX] Replace uses of jax.xla_computation() with jax.jit().lower().
Browse files Browse the repository at this point in the history
jax.xla_computation() is deprecated in favor of jax.jit(...).lower(...).

The most common replacements are either jax.jit(...).lower(...).compiler_ir(dialect='hlo') or jax.jit(...).lower(...).cost_analysis().

PiperOrigin-RevId: 509937560
  • Loading branch information
hawkinsp authored and copybara-github committed Feb 15, 2023
1 parent a8fc2ce commit e962caa
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trax/supervised/trainer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,9 @@ def save_computation_graphs(self):
if self.n_devices > 1:
batch = _reshape_by_device(batch, self.n_devices)
weights = self._opt_state.weights[0]
forward_computation = jax.xla_computation(self._model_predict_eval)(
forward_computation = jax.jit(self._model_predict_eval).lower(
batch, weights=weights, state=self._model_state[0],
rng=self._rngs[0])
rng=self._rngs[0]).compiler_ir(dialect='hlo')
with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
f.write(forward_computation.as_hlo_text())
with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
Expand Down

0 comments on commit e962caa

Please sign in to comment.