Skip to content

Commit

Permalink
[jax2tf] Remove non-native serialization test from jax_to_ir_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683124315
  • Loading branch information
gnecula authored and Google-ML-Automation committed Oct 7, 2024
1 parent 95631a7 commit 5fabd34
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 10 deletions.
2 changes: 1 addition & 1 deletion jax/tools/jax_to_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def ordered_wrapper(*args):
raise ValueError(
'Conversion to TF graph requires TensorFlow to be installed.')

f = jax2tf.convert(ordered_wrapper, native_serialization=False)
f = jax2tf.convert(ordered_wrapper)
f = tf_wrap_with_input_names(f, input_shapes)
f = tf.function(f, autograph=False)
g = f.get_concrete_function(*args).graph.as_graph_def()
Expand Down
9 changes: 0 additions & 9 deletions tests/jax_to_ir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,13 @@ def test_parse_shape_str_invalid(self):
jax_to_ir.parse_shape_str('foo[]')

@unittest.skipIf(tf is None, 'TensorFlow not installed.')
@jtu.ignore_warning(
category=UserWarning,
message='jax2tf.convert with native_serialization=False is deprecated.'
)
def test_jax_to_tf_axpy(self):
tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [
('y', jax_to_ir.parse_shape_str('f32[128]')),
('a', jax_to_ir.parse_shape_str('f32[]')),
('x', jax_to_ir.parse_shape_str('f32[128,2]')),
])

# Check that tf debug txt contains a broadcast, add, and multiply.
self.assertIn('BroadcastTo', tf_text)
self.assertIn('AddV2', tf_text)
self.assertIn('Mul', tf_text)

# Check that we can re-import our graphdef.
gdef = tf.compat.v1.GraphDef()
gdef.ParseFromString(tf_proto)
Expand Down

0 comments on commit 5fabd34

Please sign in to comment.