Hi,
I am using a slightly old tag of your repo, from when BART had run_bart_sum.py. I fine-tuned bart-large on a custom dataset and now want to run conditional generation with it:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch
model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')
ARTICLE_TO_SUMMARIZE = "President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised the administration's response to the coronavirus pandemic as a \"great success story\" on Wednesday -- less than a day after the number of confirmed coronavirus cases in the United States topped 1 million. Kushner painted a rosy picture for \"Fox and Friends\" Wednesday morning, saying that \"the federal government rose to the challenge and this is a great success story and I think that that's really what needs to be told.\""
# Attempt 1: point from_pretrained at the Lightning checkpoint directly (traceback below)
# model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
# tokenizer = BartTokenizer.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
# Attempt 2: load bart-large and set the fine-tuned state dict (error below)
model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')
state = torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu')
model.load_state_dict(state)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
inputs = tokenizer.batch_encode_plus(
    [ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
summary_ids = model.generate(
    inputs['input_ids'], num_beams=1, max_length=512, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True,
                        clean_up_tokenization_spaces=False)
       for g in summary_ids])
I tried both approaches: loading the fine-tuned checkpoint directly with from_pretrained, and loading bart-large and then calling load_state_dict with the checkpoint.
The former gives me:
Traceback (most recent call last):
  File "generate.py", line 10, in <module>
    model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/modeling_utils.py", line 438, in from_pretrained
    **kwargs,
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 200, in from_pretrained
    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 252, in get_config_dict
    config_dict = cls._dict_from_json_file(resolved_config_file)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 344, in _dict_from_json_file
    text = reader.read()
  File "/datastor/Softwarez/miniconda3/lib/python3.7/codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte
The latter gives me:
Unexpected key(s) in state_dict: "epoch", "global_step", "checkpoint_callback_best", "optimizer_states", "lr_schedulers", "state_dict", "hparams", "hparams_type".
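From the keys in that error it looks like the Lightning checkpoint nests the actual model weights under "state_dict", so I also tried extracting that and stripping the prefix the LightningModule presumably adds. This is just my guess; I am assuming the prefix is "model.", which may not match what run_bart_sum.py actually uses:

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')

ckpt = torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu')
# The weights appear to sit under ckpt["state_dict"]; the "model." prefix below
# is my assumption about how the LightningModule named its BART submodule.
state_dict = {k[len("model."):] if k.startswith("model.") else k: v
              for k, v in ckpt["state_dict"].items()}
model.load_state_dict(state_dict)

# Save in the standard transformers layout so from_pretrained can load it directly next time.
model.save_pretrained('./bart_sum_hf')
tokenizer.save_pretrained('./bart_sum_hf')

Is something like this the intended way to turn a run_bart_sum.py checkpoint into a model usable with BartForConditionalGeneration.generate, or is there a supported path I am missing?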