Closed
Description
In the test case `test_layernorm_fc_1024_different_inputs_NTK_FLAT_layer_norm=C` (`__main__.StaxTest`), parameterized as `model='fc', width=1024, same_inputs=False, is_ntk=True, proj_into_2d='FLAT', layer_norm='C'`, the following error is raised. The final line shows that the `input_shape` passed to `tfnp.zeros` contains a `TensorSpec` where a concrete `int32` dimension is expected. Traceback excerpt:
File ".../tf_jax_stax.py", line 162, in <lambda>
init_fun = lambda rng, input_shape: (tfnp.zeros(input_shape), ())
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/ops/numpy_ops/np_array_ops.py", line 66, in zeros
return np_arrays.tensor_to_ndarray(array_ops.zeros(shape, dtype=dtype))
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
return target(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 2770, in wrapped
tensor = fun(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py", line 2828, in zeros
shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1526, in convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 339, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 264, in constant
return _constant_impl(value, dtype, shape, name, verify_shape=False,
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 281, in _constant_impl
tensor_util.make_tensor_proto(
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/framework/tensor_util.py", line 457, in make_tensor_proto
_AssertCompatible(values, dtype)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/tensorflow/python/framework/tensor_util.py", line 336, in _AssertCompatible
raise TypeError("Expected %s, got %s of type '%s' instead." %
TypeError: Expected int32, got TensorSpec(shape=(), dtype=tf.int32, name=None) of type 'TensorSpec' instead.