# ❓ Questions & Help
I want to use mixed precision, and I found `tf.keras.mixed_precision.experimental.Policy`. So I call `tf.keras.mixed_precision.experimental.set_policy("mixed_float16")` before `TFBertModel.from_pretrained(pretrained_weights)`. When I run the code, I get the following error:
```
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:AddV2] name: tf_bert_model_1/bert/embeddings/add/
```

which is raised at `ret = model(model.dummy_inputs, training=False)  # build the network with dummy inputs`.
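For reference, here is a minimal sketch of what I am running (the checkpoint name is just an example standing in for the one I actually use):

```python
import tensorflow as tf
from transformers import TFBertModel

# Set the global mixed precision policy before building the model
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")

pretrained_weights = "bert-base-uncased"  # example checkpoint, not my actual one
model = TFBertModel.from_pretrained(pretrained_weights)  # fails here with the InvalidArgumentError above
```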
I am not sure if I am using it correctly. My understanding is that `tf.keras.mixed_precision.experimental.set_policy` is supposed to be called before constructing/building the model, since the TF docs say: "Policies can be passed to the 'dtype' argument of layer constructors, or a global policy can be set with 'tf.keras.mixed_precision.experimental.set_policy'."
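To illustrate the two usages the docs describe, this is my understanding with a plain Keras layer (a toy sketch, unrelated to transformers):

```python
import tensorflow as tf

# Usage 1: pass a Policy to an individual layer's dtype argument
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
dense_a = tf.keras.layers.Dense(8, dtype=policy)

# Usage 2: set a global policy; layers constructed afterwards pick it up
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
dense_b = tf.keras.layers.Dense(8)  # inherits the global mixed_float16 policy
```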
I wonder if I can use AMP with the TF-based transformer models, and if so, how. Thanks.
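For completeness, the AMP variant I had in mind is the optimizer graph rewrite, which casts eligible ops automatically instead of changing layer dtypes; I don't know whether this is the recommended path for the TF transformer models either (sketch only, with an example checkpoint):

```python
import tensorflow as tf
from transformers import TFBertModel

model = TFBertModel.from_pretrained("bert-base-uncased")  # example checkpoint

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
# AMP via graph rewrite: eligible ops run in float16 and loss scaling is added,
# without the layers themselves being built under a float16 policy
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
```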