TFGPT2LMHeadModel fp16 support #8559

Closed
@mymusise

Description

Environment info

  • transformers version:
  • Platform: Ubuntu 18.04
  • Python version: 3.8
  • Tensorflow version (GPU?): tf-nightly==2.5
  • Using GPU in script?: Y
  • Using distributed or parallel set-up in script?: N

Who can help

albert, bert, GPT2, XLM: @LysandreJik
Text Generation: @patrickvonplaten @TevenLeScao
tensorflow: @jplu

Information

Hi there. I want to use the mixed precision policy from the Keras API when training TFGPT2LMHeadModel, like this:

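# Assumed import: set_policy lives in the experimental mixed precision namespace
from tensorflow.keras.mixed_precision import experimental as mixed_precision

policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)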

Then I get this error:

  File "/home/mymusise/pro/fast-gpt2/env/lib/python3.8/site-packages/transformers/modeling_tf_gpt2.py", line 154, in call
    attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
  File "/home/mymusise/pro/fast-gpt2/env/lib/python3.8/site-packages/transformers/modeling_tf_gpt2.py", line 101, in _attn
    w = w / tf.math.sqrt(dk)
  File "/home/mymusise/pro/fast-gpt2/env/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 1181, in binary_op_wrapper
    raise e
  File "/home/mymusise/pro/fast-gpt2/env/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 1165, in binary_op_wrapper
    return func(x, y, name=name)
  File "/home/mymusise/pro/fast-gpt2/env/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 206, in wrapper
    return target(*args, **kwargs)
  File "/home/mymusise/pro/fast-gpt2/env/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 1337, in truediv
    return _truediv_python3(x, y, name)
  File "/home/mymusise/pro/fast-gpt2/env/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 1267, in _truediv_python3
    raise TypeError("x and y must have the same dtype, got %r != %r" %
TypeError: x and y must have the same dtype, got tf.float16 != tf.float32

Here's an example to reproduce this.

Please help me guys.
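
The traceback points at `w = w / tf.math.sqrt(dk)` in `_attn`, where `w` is float16 under the mixed precision policy and `dk` appears to be float32. Below is a minimal sketch of that mismatch and one possible workaround, casting `dk` to the dtype of `w`. The function name `scaled_scores` and the tensor shapes are hypothetical and chosen for illustration only, and the assumption that `dk` is built with a hard-coded float32 cast is inferred from the error message; this is not the library's actual code or its fix.

```python
import tensorflow as tf


def scaled_scores(w: tf.Tensor, k: tf.Tensor) -> tf.Tensor:
    """Scale attention scores by sqrt(d_k), keeping everything in w's dtype.

    Hypothetical helper for illustration; not part of transformers.
    """
    # Cast the key dimension to the dtype of w instead of a fixed tf.float32,
    # so the division works under both float32 and mixed_float16 policies.
    dk = tf.cast(tf.shape(k)[-1], dtype=w.dtype)
    return w / tf.math.sqrt(dk)


# Stand-ins for the float16 activations produced under mixed precision.
w = tf.ones((2, 4, 4), dtype=tf.float16)
k = tf.ones((2, 4, 8), dtype=tf.float16)

# Assumed shape of the failing code path: dk cast to float32, w in float16.
try:
    _ = w / tf.math.sqrt(tf.cast(tf.shape(k)[-1], tf.float32))
    mismatch_raised = False
except (TypeError, tf.errors.InvalidArgumentError):
    mismatch_raised = True

# With the dtype-aware cast, the division succeeds and stays in float16.
out = scaled_scores(w, k)
print(mismatch_raised, out.dtype)
```

Casting to `w.dtype` rather than to a literal `tf.float16` keeps the same code path valid when no mixed precision policy is set.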
