
tf.keras.mixed_precision.experimental.Policy #2005

Closed
@ydshieh

Description


❓ Questions & Help

I want to use mixed precision, and I found tf.keras.mixed_precision.experimental.Policy.

So I put tf.keras.mixed_precision.experimental.set_policy("mixed_float16") before TFBertModel.from_pretrained(pretrained_weights). When I ran the code, I got the following error:

InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:AddV2] name: tf_bert_model_1/bert/embeddings/add/

This happens at ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs.
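For reproduction, here is a minimal sketch of what I am running (the checkpoint name "bert-base-uncased" is just an example stand-in for my pretrained_weights):

```python
import tensorflow as tf
from transformers import TFBertModel

# Set the global mixed-precision policy before building the model,
# as the TF docs suggest.
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")

pretrained_weights = "bert-base-uncased"  # example checkpoint
model = TFBertModel.from_pretrained(pretrained_weights)
# from_pretrained internally builds the network with dummy inputs:
#   ret = model(model.dummy_inputs, training=False)
# and this is where the AddV2 dtype error above is raised.
```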

I am not sure if I used it correctly. I think tf.keras.mixed_precision.experimental.set_policy is supposed to be called before constructing / building the model, as the TF docs say: "Policies can be passed to the 'dtype' argument of layer constructors, or a global policy can be set with 'tf.keras.mixed_precision.experimental.set_policy'."
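For reference, my understanding of the two usages from the TF docs, sketched on a plain Dense layer:

```python
import tensorflow as tf

# Option 1: pass a policy to an individual layer's constructor.
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
dense_a = tf.keras.layers.Dense(16, dtype=policy)

# Option 2: set a global policy; layers built afterwards pick it up.
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
dense_b = tf.keras.layers.Dense(16)  # inherits the global mixed_float16 policy
```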

I wonder whether I can use AMP with the TF-based transformer models, and how (the only other AMP entry point I am aware of is sketched below). Thanks.
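The other AMP route I am aware of is the automatic graph rewrite, which wraps the optimizer instead of changing layer dtypes; a sketch, assuming a standard Keras optimizer (whether this works with TFBertModel is exactly my question):

```python
import tensorflow as tf

# Wrap the optimizer; the rewrite inserts casts and loss scaling
# automatically, while layer variables stay float32.
opt = tf.keras.optimizers.Adam(learning_rate=3e-5)
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
```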

error.txt
