Closed
Description
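As titled. Calling `init_fun` fails inside `tfnp.zeros(output_shape)` with `InvalidArgumentError: Dimension -1 must be >= 0`, which means `output_shape` contains a `-1` dimension (presumably a placeholder for an unspecified batch size). The traceback below is truncated at the top: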
File ".../tf_jax_stax.py", line 96, in init_fun
return tfnp.zeros(output_shape), (W, b)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/trax/tf_numpy/numpy_impl/array_ops.py", line 83, in zeros
return arrays_lib.tensor_to_ndarray(tf.zeros(shape, dtype=dtype))
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 2677, in wrapped
tensor = fun(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 2733, in zeros
output = fill(shape, constant(zero, dtype=dtype), name=name)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 234, in fill
result = gen_array_ops.fill(dims, value, name=name)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3316, in fill
_ops.raise_from_not_ok_status(e, name)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 6653, in raise_from_not_ok_status
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension -1 must be >= 0 [Op:Fill]