cast_bool_to_primitive breaks TensorFlow graph support. #5815

@AndreasMadsen

🐛 Bug

To reproduce

import tensorflow as tf
import transformers

bert = tf.function(transformers.TFBertForMaskedLM.from_pretrained('bert-base-uncased'))

for i in range(2):
    (_, hidden_state) = bert(tf.constant([[10, 11, 12]]), output_hidden_states=True)
    print(f'computed {i}')

This fails with:

ValueError: not enough values to unpack (expected 2, got 1)

Expected behavior

computed 0
computed 1

That is, the same result as when tf.function is not used.

Environment info

Example environment: https://colab.research.google.com/gist/AndreasMadsen/593df94a3319dee58bba33a26efedeb3/untitled6.ipynb

  • transformers version: 3.0.2
  • Platform: Linux-4.19.104+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.5.1+cu101 (False)
  • Tensorflow version (GPU?): 2.2.0 (False)
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Details

The bug is caused by cast_bool_to_primitive, which was introduced in 6e603cb. Before that commit, it was possible to get the hidden_states from Bert in TensorFlow graph/function mode.

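Generally speaking, casting TensorFlow tensors to Python primitives is not good practice, because it only works in eager mode: inside tf.function the tensors are symbolic and do not have a concrete value yet. It is also unnecessary in this case, since a plain if bool_tensor_scalar: works in both modes; under tf.function, AutoGraph converts it into a graph conditional.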

import tensorflow as tf

def print_bool(x):
    # tf.print runs every time the function executes. A Python print() inside
    # tf.function only runs while tracing, which would hide what each call does.
    if x:
        tf.print('True')
    else:
        tf.print('False')

print_bool_graph = tf.function(print_bool)

print('eager:')
print_bool(True)  # Prints True
print_bool(False)  # Prints False
print_bool(tf.constant(True))  # Prints True
print_bool(tf.constant(False))  # Prints False

print('')
print('graph:')
print_bool_graph(True)  # Prints True
print_bool_graph(False)  # Prints False
print_bool_graph(tf.constant(True))  # Prints True
print_bool_graph(tf.constant(False))  # Prints False
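To make the eager-only failure mode concrete, here is a minimal sketch built on a hypothetical helper, eager_only_cast, that reads a tensor's value through .numpy() and otherwise falls back to a default. It is an illustrative simplification of the pattern criticized above, not the actual transformers source. In eager mode it returns the real value; inside tf.function the tensor is symbolic, has no .numpy(), and the default silently wins. A True flag read as False in this way would be consistent with the single returned value in the error above.

```python
import tensorflow as tf


def eager_only_cast(x, default=False):
    """Hypothetical helper in the spirit of cast_bool_to_primitive.

    Not the transformers implementation: it only reproduces the pattern of
    reading a tensor's value through .numpy(), which eager tensors have and
    symbolic (graph) tensors do not.
    """
    if tf.is_tensor(x):
        if hasattr(x, "numpy"):  # eager tensor: the concrete value is available
            return bool(x.numpy())
        return default  # symbolic tensor: the real value is silently dropped
    return x  # plain Python values pass through unchanged


@tf.function
def cast_in_graph(flag):
    # Inside tf.function, `flag` is a symbolic tensor without .numpy().
    return tf.constant(eager_only_cast(flag))


print(eager_only_cast(tf.constant(True)))        # True
print(cast_in_graph(tf.constant(True)).numpy())  # False
```

The same input gives different answers depending on whether the code is traced, which is why a plain if on the tensor is the safer choice.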

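There are some cases where cast_bool_to_primitive is also used to fall back to a default value. The right way to handle that is to resolve the defaults upstream, in the first (outermost) call() method. A lesser alternative is to keep the helper but only substitute the default when no value was passed: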

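def cast_bool_to_primitive(x, default_value=False):
    # Only substitute the default when no value was passed; a bool tensor
    # is returned unchanged, so it keeps working in graph mode.
    if x is None:
        return default_value
    return x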
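A minimal sketch of the preferred approach, resolving defaults once in the outermost call(), might look like the following. Config, Encoder and Model are hypothetical pure-Python stand-ins, not transformers classes; the point is only that inner layers always receive an explicit flag, so no None check or cast is needed below the top level.

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical stand-in for a model config that holds the default.
    output_hidden_states: bool = False


class Encoder:
    # Hypothetical inner layer: it always receives an explicit flag, never None.
    def call(self, inputs, output_hidden_states):
        hidden_states = (inputs, inputs + 1)  # placeholder per-layer outputs
        if output_hidden_states:
            return (hidden_states[-1], hidden_states)
        return (hidden_states[-1],)


class Model:
    # Hypothetical outermost layer: the only place where defaults are resolved.
    def __init__(self, config):
        self.config = config
        self.encoder = Encoder()

    def call(self, inputs, output_hidden_states=None):
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        return self.encoder.call(inputs, output_hidden_states)


model = Model(Config())
print(len(model.call(10)))                             # 1
print(len(model.call(10, output_hidden_states=True)))  # 2
```

Keeping the None handling at the top means the inner code never needs to inspect or convert the flag's type.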
