Using FP16 on BartModel #3249

@AOZMH

🐛 Bug

Information

Model I am using (Bert, XLNet ...): BART

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: CNN/DM
  • my own task or dataset: (give details below)

To reproduce

I installed the master branch of transformers, but I still hit the same issue as #3117 when using BartModel in FP16. I initialized the model without loading the pretrained weights, but I would expect it to still forward an input LongTensor of shape (batch, seq_length) correctly. The code below simply initializes a model and forwards an input:

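import torch
from transformers import BartConfig, BartModel

model = BartModel(BartConfig())  # randomly initialized, no pretrained weights
model = model.cuda().half()  # move to GPU and cast floating-point weights to FP16
cur_inputs = torch.zeros(4, 16, dtype=torch.long).cuda()  # token ids, shape (batch, seq_length)
cur_res = model(cur_inputs)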

The error is:

~\Anaconda3\envs\pytorch\lib\site-packages\transformers\modeling_bart.py in forward(self, query, key, value, key_padding_mask, layer_state, need_weights, static_kv, attn_mask)
assert v is not None
--> attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm
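The message says the two operands of torch.bmm have different dtypes: the first (attn_probs) is Float while the second (v) is Half. The following is a minimal sketch of that class of error and one generic way to avoid it. It is not the fix applied in modeling_bart.py. The helper names dtypes_match and bmm_cast_to_value_dtype are hypothetical, and float32/float64 stand in for Float/Half so the snippet also runs on CPU.

```python
import torch


def dtypes_match(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Hypothetical helper: True when both tensors share the same dtype.
    return a.dtype == b.dtype


def bmm_cast_to_value_dtype(attn_probs: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: cast the first operand to the dtype of the second
    # so that torch.bmm receives operands of the same type.
    return torch.bmm(attn_probs.to(v.dtype), v)


# Stand-in tensors: (batch * heads, tgt_len, src_len) and (batch * heads, src_len, head_dim).
attn_probs = torch.ones(2, 3, 4, dtype=torch.float32)
v = torch.ones(2, 4, 5, dtype=torch.float64)

print(dtypes_match(attn_probs, v))  # the dtypes differ, as in the traceback
out = bmm_cast_to_value_dtype(attn_probs, v)
print(out.dtype, tuple(out.shape))
```

Casting at the call site only hides the mismatch. To locate where the Float tensor is created, it helps to print the dtype of each intermediate tensor in the attention forward pass.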

@sshleifer The model is quite new to me, so am I using it incorrectly, or is there still a bug in the BartModel class? Thanks in advance for the help!

Environment info

  • transformers version: master branch
  • Platform: Windows
  • Python version: 3.7.0
  • PyTorch version (GPU?): 1.4.0
  • Tensorflow version (GPU?): /
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No
