BART FP16 #3117

Closed
@astariul

Description

🚀 Feature request

I would like to use BART in FP16 mode, but it seems impossible for now:

from transformers import AutoModelWithLMHead, AutoTokenizer, BartConfig

# Load bart-large-cnn, convert the weights to FP16 and move everything to GPU
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config).cuda().half()
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')

ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
generated_ids = model.generate(inputs['input_ids'].cuda(), attention_mask=inputs['attention_mask'].cuda(), num_beams=4, max_length=5)

File "/data/user/.venv/bartqg/lib/python3.6/site-packages/transformers/modeling_bart.py", line 647, in forward
attn_output = torch.bmm(attn_probs, v)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm
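
The error looks like a dtype mismatch inside the attention: attn_probs is still float32 while v has been converted to half. Whether the float32 tensor comes from the softmax or from the attention mask is an assumption on my part, but here is a minimal sketch that reproduces the mismatch and the kind of cast that avoids it:

import torch

# float32 attention probabilities vs. half-precision values (same shapes as a single head)
attn_probs = torch.softmax(torch.randn(1, 5, 5, device='cuda'), dim=-1)  # float32
v = torch.randn(1, 5, 64, device='cuda').half()                          # float16

# torch.bmm(attn_probs, v) raises the same "Float ... Half" error here.
# Casting the probabilities to the value dtype before the matmul avoids it:
attn_output = torch.bmm(attn_probs.to(v.dtype), v)
print(attn_output.dtype)  # torch.float16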

@sshleifer Do you plan to implement an FP16-friendly version of BART?
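
In the meantime, a possible workaround (assuming a PyTorch build with AMP support, i.e. 1.6+) is to keep the weights in float32 and let autocast run the heavy matmuls in half precision during generation. This is only a sketch, not a tested recipe for this exact transformers version:

import torch
from transformers import AutoModelWithLMHead, AutoTokenizer

model = AutoModelWithLMHead.from_pretrained('bart-large-cnn').cuda()  # weights stay fp32
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')

inputs = tokenizer.batch_encode_plus(["My friends are cool but they eat too many carbs."],
                                     max_length=1024, return_tensors='pt')

# autocast picks fp16 for the large matmuls while keeping dtype-sensitive ops in fp32,
# which should sidestep the Float/Half mismatch in torch.bmm.
with torch.cuda.amp.autocast():
    generated_ids = model.generate(inputs['input_ids'].cuda(),
                                   attention_mask=inputs['attention_mask'].cuda(),
                                   num_beams=4, max_length=5)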
