Description
🐛 Bug
To reproduce
import tensorflow as tf
import transformers

bert = tf.function(transformers.TFBertForMaskedLM.from_pretrained('bert-base-uncased'))
for i in range(2):
    (_, hidden_state) = bert(tf.constant([[10, 11, 12]]), output_hidden_states=True)
    print(f'computed {i}')
This fails with:
ValueError: not enough values to unpack (expected 2, got 1)
Expected behavior
computed 0
computed 1
That is, the same result as when tf.function is not used.
Environment info
Example environment: https://colab.research.google.com/gist/AndreasMadsen/593df94a3319dee58bba33a26efedeb3/untitled6.ipynb
- transformers version: 3.0.2
- Platform: Linux-4.19.104+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.6.9
- PyTorch version (GPU?): 1.5.1+cu101 (False)
- Tensorflow version (GPU?): 2.2.0 (False)
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Details
The bug is caused by cast_bool_to_primitive, which was introduced in 6e603cb. Before that commit, it was possible to get the hidden_states from BERT in TensorFlow graph/function mode.
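Generally speaking, casting TensorFlow tensors to Python primitives is not good practice, as it only works in eager mode. It is also completely unnecessary in this case, because a plain if statement on a scalar boolean tensor (if bool_tensor_scalar:) works perfectly fine: inside tf.function, AutoGraph converts such an if statement into a graph conditional (tf.cond).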
import tensorflow as tf

def print_bool(x):
    if x:
        print('True')
    else:
        print('False')

print_bool_graph = tf.function(print_bool)

print('eager:')
print_bool(True)  # Prints True
print_bool(False)  # Prints False
print_bool(tf.constant(True))  # Prints True
print_bool(tf.constant(False))  # Prints False
print('')
print('graph:')
print_bool_graph(True)  # Prints True
print_bool_graph(False)  # Prints False
print_bool_graph(tf.constant(True))  # Prints True
print_bool_graph(tf.constant(False))  # Prints False
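The graph-mode half of this demo relies on AutoGraph rewriting if x: into a graph conditional when x is a tensor. As a minimal sketch of what that amounts to without AutoGraph (autograph=False, so nothing is rewritten), the tensor case can be written directly with tf.cond, while a plain Python bool is simply a trace-time constant that tf.function retraces per value. In neither case is the flag cast to a Python primitive. The names pick and pick_static are illustrative only and are not part of transformers.

```python
import tensorflow as tf


# autograph=False: no source rewriting, this is the explicit graph form.
@tf.function(autograph=False)
def pick(flag, a, b):
    # flag is a scalar tf.bool tensor. tf.cond builds both branches into the
    # graph and selects one at run time; flag never becomes a Python bool.
    return tf.cond(flag, lambda: a, lambda: b)


@tf.function(autograph=False)
def pick_static(flag, a, b):
    # flag is a plain Python bool. tf.function keeps it as a Python value and
    # traces a separate graph per value, so an ordinary `if` is fine here.
    if flag:
        return a
    return b


a = tf.constant(1)
b = tf.constant(2)
print(int(pick(tf.constant(True), a, b)))   # 1
print(int(pick(tf.constant(False), a, b)))  # 2
print(int(pick_static(True, a, b)))         # 1
print(int(pick_static(False, a, b)))        # 2
```

With AutoGraph enabled (the tf.function default), writing if flag: in pick would be converted to roughly the tf.cond call shown.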
I can see that defaults are used in some cases. The right way to handle them is to resolve the defaults upstream, in the first call() method. A lesser alternative would be to implement cast_bool_to_primitive as:
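def cast_bool_to_primitive(x, default_value=False):
    # Only substitute the default when no value was given; Python bools and
    # bool tensors are passed through unchanged, so `if x:` still works in graph mode.
    if x is None:
        return default_value
    return x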
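The "right way" described above, resolving defaults once in the outermost call(), might look roughly like the following framework-free sketch. Everything here is hypothetical: HypotheticalConfig, HypotheticalModel, num_layers and the x + 1 stand-in for a transformer layer are invented for illustration and are not the transformers API. The point is that once the default is resolved at the top, output_hidden_states reaches every inner layer as a plain Python bool, which tf.function would treat as a trace-time constant, so no cast is needed anywhere below.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class HypotheticalConfig:
    # Hypothetical stand-in for a model config holding the default flag.
    output_hidden_states: bool = False
    num_layers: int = 2


class HypotheticalModel:
    def __init__(self, config: HypotheticalConfig):
        self.config = config

    def call(self, x: int, output_hidden_states: Optional[bool] = None) -> Tuple:
        # Resolve the default exactly once, in the outermost call().
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        # From here on the flag is always a plain Python bool.
        return self._encoder(x, output_hidden_states)

    def _encoder(self, x: int, output_hidden_states: bool) -> Tuple:
        all_hidden = []
        for _ in range(self.config.num_layers):
            if output_hidden_states:  # ordinary `if`, no primitive cast
                all_hidden.append(x)
            x = x + 1  # stand-in for a transformer layer
        if output_hidden_states:
            all_hidden.append(x)  # final layer output
            return (x, tuple(all_hidden))
        return (x,)


model = HypotheticalModel(HypotheticalConfig())
print(model.call(0))                             # (2,)
print(model.call(0, output_hidden_states=True))  # (2, (0, 1, 2))
```

Because the output structure (one or two elements) is decided by a Python bool, it is fixed per trace, which is what unpacking (_, hidden_state) in the reproduction expects.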