🚀 Feature request
I would like to use BART in FP16 mode, but it does not seem possible at the moment:
```python
from transformers import AutoModelWithLMHead, AutoTokenizer, BartConfig

# Load BART for summarization and cast it to half precision
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config).cuda().half()
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')

ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')

# Generation fails with a dtype mismatch inside the attention layers
generated_ids = model.generate(inputs['input_ids'].cuda(), attention_mask=inputs['attention_mask'].cuda(), num_beams=4, max_length=5)
```
The call to `generate` fails with:

```
File "/data/user/.venv/bartqg/lib/python3.6/site-packages/transformers/modeling_bart.py", line 647, in forward
    attn_output = torch.bmm(attn_probs, v)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm
```
@sshleifer Do you plan to implement an FP16-friendly version of BART?
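
For reference, a possible interim workaround might be PyTorch's automatic mixed precision instead of casting the whole model to half precision. This is only a sketch: it assumes a PyTorch version that ships `torch.cuda.amp.autocast`, and it has not been verified to avoid the `torch.bmm` dtype mismatch above.

```python
import torch
from transformers import AutoModelWithLMHead, AutoTokenizer, BartConfig

# Sketch of an autocast-based workaround (assumes torch.cuda.amp is available):
# weights stay in FP32 and eligible ops run in FP16 inside the autocast region.
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')

inputs = tokenizer.batch_encode_plus(
    ["My friends are cool but they eat too many carbs."],
    max_length=1024, return_tensors='pt'
)

with torch.no_grad(), torch.cuda.amp.autocast():
    generated_ids = model.generate(
        inputs['input_ids'].cuda(),
        attention_mask=inputs['attention_mask'].cuda(),
        num_beams=4, max_length=5,
    )
```

Since the weights remain in FP32, the memory savings would be smaller than with `model.half()`, so native FP16 support would still be valuable.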