Use finetuned-BART large to do conditional generation #4144

Closed
@tuhinjubcse

Hi

I am using a slightly older tag of your repo, from when the BART examples still included run_bart_sum.py. I fine-tuned bart-large on a custom dataset and now want to use the result for conditional generation.

from transformers import BartTokenizer, BartForConditionalGeneration
import torch

model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')

ARTICLE_TO_SUMMARIZE = "President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised the administration's response to the coronavirus pandemic as a \"great success story\" on Wednesday -- less than a day after the number of confirmed coronavirus cases in the United States topped 1 million. Kushner painted a rosy picture for \"Fox and Friends\" Wednesday morning, saying that \"the federal government rose to the challenge and this is a great success story and I think that that's really what needs to be told.\""


# model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
# tokenizer = BartTokenizer.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')

model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')
state = torch.load('./bart_sum/checkpointepoch=2.ckpt',map_location='cpu')
model.load_state_dict(state)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


inputs = tokenizer.batch_encode_plus(
    [ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
summary_ids = model.generate(
    inputs['input_ids'].to(device), num_beams=1, max_length=512, early_stopping=True)  # inputs must be on the same device as the model


print([tokenizer.decode(g, skip_special_tokens=True,
                        clean_up_tokenization_spaces=False)
       for g in summary_ids])


I tried two approaches: loading the fine-tuned checkpoint directly with from_pretrained, and loading bart-large and then calling load_state_dict with the checkpoint.

The former gives me:

Traceback (most recent call last):
  File "generate.py", line 10, in <module>
    model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/modeling_utils.py", line 438, in from_pretrained
    **kwargs,
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 200, in from_pretrained
    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 252, in get_config_dict
    config_dict = cls._dict_from_json_file(resolved_config_file)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 344, in _dict_from_json_file
    text = reader.read()
  File "/datastor/Softwarez/miniconda3/lib/python3.7/codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte
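
For what it's worth, byte 0x80 at position 0 is how a pickle stream (protocol 2 and later) begins, and the traceback shows the .ckpt path being read as a JSON config file. The sketch below only reproduces the same decode error with a small pickled dict standing in for the checkpoint; the `payload` contents and the helper name `decode_error_message` are made up for illustration.

```python
import pickle

# Stand-in for the checkpoint file: any object pickled with protocol >= 2
# starts with the PROTO opcode, which is byte 0x80.
payload = pickle.dumps({"state_dict": {}}, protocol=2)
first_byte = payload[0]


def decode_error_message(data: bytes) -> str:
    """Return the UnicodeDecodeError text raised by decoding data as UTF-8, or '' if it decodes."""
    try:
        data.decode("utf-8")
    except UnicodeDecodeError as err:
        return str(err)
    return ""


message = decode_error_message(payload)
print(hex(first_byte))
print(message)
```

If that reading is right, the file is a pickled checkpoint rather than something from_pretrained can parse as a config, so it would have to be opened with torch.load instead.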

The latter gives:
Unexpected key(s) in state_dict: "epoch", "global_step", "checkpoint_callback_best", "optimizer_states", "lr_schedulers", "state_dict", "hparams", "hparams_type".
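
The key names above look like a trainer-style checkpoint in which the model weights are nested under "state_dict" next to optimizer and scheduler state. A minimal sketch of unwrapping such a dict follows. The helper name `extract_model_state` is hypothetical, and the "model." prefix is an assumption about how the training wrapper named its attribute, so it may need adjusting. The toy dict uses plain floats in place of tensors.

```python
from collections import OrderedDict


def extract_model_state(checkpoint, prefix="model."):
    """Pull the model weights out of a trainer-style checkpoint dict."""
    # The weights sit under "state_dict"; the remaining top-level keys
    # ("epoch", "optimizer_states", ...) are training bookkeeping.
    raw_state = checkpoint["state_dict"]
    cleaned = OrderedDict()
    for name, value in raw_state.items():
        # Assumption: the wrapper stored the network as an attribute called
        # "model", so each key carries one leading "model." to strip.
        if prefix and name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = value
    return cleaned


# Toy checkpoint with the same top-level layout as in the error message.
toy_checkpoint = {
    "epoch": 2,
    "global_step": 100,
    "state_dict": {
        "model.final_logits_bias": 0.0,
        "model.model.shared.weight": 1.0,
    },
}
weights = extract_model_state(toy_checkpoint)
print(list(weights))
```

With the real file, the dict returned by torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu') would be passed through the helper and the result given to model.load_state_dict. If keys still do not match, comparing a few names from both sides should show which prefix to strip.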
