🐛 Bug
Information
Model I am using (Bert, XLNet ...): BART
Language I am using the model on (English, Chinese ...): English
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
The task I am working on is:
- an official GLUE/SQuAD task: CNN/DM
- my own task or dataset: (give details below)
To reproduce
I've installed transformers from the master branch, but I still run into the same issue as #3117 when using BartModel in FP16. I initialized the model without loading the pretrained weights, but the model should still be able to correctly forward an input LongTensor of shape (batch, seq_length). The code below simply initializes a model and forwards an input:
import torch
from transformers import BartConfig, BartModel

model = BartModel(BartConfig())
model = model.cuda().half()
cur_inputs = torch.zeros(4, 16, dtype=torch.long).cuda()
cur_res = model(cur_inputs)
The error is:
~\Anaconda3\envs\pytorch\lib\site-packages\transformers\modeling_bart.py in forward(self, query, key, value, key_padding_mask, layer_state, need_weights, static_kv, attn_mask)
assert v is not None
--> attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm
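The same dtype mismatch can be reproduced outside the model with a minimal sketch (an illustration of the failure mode, not the exact code path in modeling_bart.py): the attention probabilities are float32 while v has been cast to half, and torch.bmm refuses mixed dtypes:

```python
import torch

# attn_probs is produced in float32 (e.g. by a softmax that upcasts),
# while v was cast to half along with the model weights.
attn_probs = torch.softmax(torch.randn(2, 4, 4), dim=-1)  # float32
v = torch.randn(2, 4, 8).half()                           # float16

try:
    torch.bmm(attn_probs, v)  # mixed dtypes: raises RuntimeError
except RuntimeError as e:
    print("bmm failed:", e)

# Casting the operands to a common dtype avoids the error:
out = torch.bmm(attn_probs, v.to(attn_probs.dtype))
print(out.dtype)  # torch.float32
```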
@sshleifer The model is quite new to me, so am I using it incorrectly, or is there still a bug in the BartModel class? Thanks in advance for the help!
Environment info
- transformers version: master branch
- Platform: Windows
- Python version: 3.7.0
- PyTorch version (GPU?): 1.4.0
- Tensorflow version (GPU?): /
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No