Skip to content

Commit

Permalink
Merge pull request OFA-Sys#79 from yangapku/feature/vqa
Browse files Browse the repository at this point in the history
VQA: add option for beam-search validation during fine-tuning
  • Loading branch information
yangapku authored Apr 24, 2022
2 parents ab0b0b4 + 5afe8d1 commit eb82a1c
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 57 deletions.
5 changes: 5 additions & 0 deletions run_scripts/vqa/train_vqa_base_distributed.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ ema_fp32="--ema-fp32"
ema_decay=0.9999
ema_start_update=0

# Specify the inference type in validation after each fine-tuning epoch
# As mentioned in the readme, you can choose from allcand or beamsearch evaluation, default to allcand
val_inference_type=allcand

for max_epoch in {15,}; do
echo "max_epoch "${max_epoch}
for warmup_ratio in {0.04,}; do
Expand Down Expand Up @@ -134,6 +138,7 @@ for max_epoch in {15,}; do
${ema_fp32} \
--ema-decay=${ema_decay} \
--ema-start-update=${ema_start_update} \
--val-inference-type=${val_inference_type} \
--num-workers=0 > ${log_file} 2>&1
done
done
Expand Down
5 changes: 5 additions & 0 deletions run_scripts/vqa/train_vqa_distributed.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ ema_fp32="--ema-fp32"
ema_decay=0.9999
ema_start_update=0

# Specify the inference type in validation after each fine-tuning epoch
# As mentioned in the readme, you can choose from allcand or beamsearch evaluation, default to allcand
val_inference_type=allcand

for total_num_updates in {40000,}; do
echo "total_num_updates "${total_num_updates}
for warmup_updates in {1000,}; do
Expand Down Expand Up @@ -135,6 +139,7 @@ for total_num_updates in {40000,}; do
${ema_fp32} \
--ema-decay=${ema_decay} \
--ema-start-update=${ema_start_update} \
--val-inference-type=${val_inference_type} \
--num-workers=0 > ${log_file} 2>&1
done
done
Expand Down
164 changes: 107 additions & 57 deletions tasks/mm_tasks/vqa_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@
logger = logging.getLogger(__name__)


def get_symbols_to_strip_from_output(generator):
if hasattr(generator, "symbols_to_strip_from_output"):
return generator.symbols_to_strip_from_output
else:
return {generator.bos, generator.eos}


def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
x = tgt_dict.string(x.int().cpu(), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator))
if bpe is not None:
x = bpe.decode(x)
if tokenizer is not None:
x = tokenizer.decode(x)
return x


@dataclass
class VqaGenConfig(OFAConfig):
max_object_length: int = field(
Expand Down Expand Up @@ -56,6 +72,16 @@ class VqaGenConfig(OFAConfig):
default=False,
metadata={"help": "whether to use ema"},
)
val_inference_type: Optional[str] = field(
default='allcand',
metadata={"help": "inference type in validation (allcand or beamsearch), default to allcand"},
)
eval_args: Optional[str] = field(
default='{"beam":5,"unnormalized":true,"temperature":1.0}',
metadata={
"help": 'generation args as JSON string for inference, only activated when --val-inference-type=beamsearch'
},
)


@register_task("vqa_gen", dataclass=VqaGenConfig)
Expand All @@ -71,6 +97,9 @@ def __init__(self, cfg: VqaGenConfig, src_dict, tgt_dict):

self.uses_ema = self.cfg.uses_ema

assert self.cfg.val_inference_type in ["allcand", "beamsearch"], \
"Unknown inference type encountered: {}, should be allcand or beamsearch.".format(self.cfg.val_inference_type)

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
paths = self.cfg.data.split(',')
assert len(paths) > 0
Expand Down Expand Up @@ -121,11 +150,19 @@ def build_model(self, cfg):
constraint_mask[i][constraint_nodes] = True
constraint_mask_list.append(constraint_mask)

self.valid_answers_list = []
self.valid_constraint_masks_list = []
for i in range(0, len(answer_item_list), self.cfg.valid_batch_size):
self.valid_answers_list += [answer_item_list[i:i+self.cfg.valid_batch_size]]
self.valid_constraint_masks_list += [constraint_mask_list[i:i+self.cfg.valid_batch_size]]
if self.cfg.val_inference_type == "allcand":
self.valid_answers_list = []
self.valid_constraint_masks_list = []
for i in range(0, len(answer_item_list), self.cfg.valid_batch_size):
self.valid_answers_list += [answer_item_list[i:i+self.cfg.valid_batch_size]]
self.valid_constraint_masks_list += [constraint_mask_list[i:i+self.cfg.valid_batch_size]]
elif self.cfg.val_inference_type == "beamsearch":
gen_args = json.loads(self.cfg.eval_args)
self.generator = self.build_generator(
[model], Namespace(**gen_args)
)
else:
raise NotImplementedError("Error: Unknown inference type encountered.")

return model

Expand All @@ -149,58 +186,71 @@ def valid_step(self, sample, model, criterion, **extra_kwargs):

eval_model.eval()
with torch.no_grad():
encoder_out = eval_model.encoder(
sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
patch_images=sample["net_input"]["patch_images"],
patch_masks=sample["net_input"]["patch_masks"]
)
device = sample["net_input"]["src_tokens"].device
eos_item = torch.tensor([self.src_dict.eos()])
pad = self.src_dict.pad()
valid_result = []
for valid_answers, valid_constraint_masks in zip(self.valid_answers_list, self.valid_constraint_masks_list):
valid_size = len(valid_answers)
valid_tgt_items = [
torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
]
valid_prev_items = [
torch.cat([torch.tensor(decoder_prompt), valid_answer])
for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
]
valid_constraint_mask_items = [
torch.cat([torch.zeros(len(decoder_prompt)-1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask], dim=0)
for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
]
valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad, left_pad=False).to(device)
valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad, left_pad=False).to(device)
valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad, left_pad=False).to(device)

new_encoder_out = {}
new_encoder_out["encoder_out"] = [
encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
]
new_encoder_out["encoder_padding_mask"] = [
encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
]
new_encoder_out["position_embeddings"] = [
encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
]

decoder_out = eval_model.decoder(valid_prev_output, encoder_out=new_encoder_out)
decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
lprobs = eval_model.get_normalized_probs(decoder_out, log_probs=True)
scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
scores = scores.masked_fill(valid_tgt.eq(self.tgt_dict.pad()), 0)
scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
scores = scores.sum(1)
scores = scores.view(-1, valid_size)
valid_result.append(scores)

valid_result = torch.cat(valid_result, dim=-1)
predicts = valid_result.argmax(1).tolist()
hyps = [self.index2ans[predict_index] for predict_index in predicts]
if self.cfg.val_inference_type == "allcand":
encoder_out = eval_model.encoder(
sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
patch_images=sample["net_input"]["patch_images"],
patch_masks=sample["net_input"]["patch_masks"]
)
device = sample["net_input"]["src_tokens"].device
eos_item = torch.tensor([self.src_dict.eos()])
pad = self.src_dict.pad()
valid_result = []
for valid_answers, valid_constraint_masks in zip(self.valid_answers_list, self.valid_constraint_masks_list):
valid_size = len(valid_answers)
valid_tgt_items = [
torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
]
valid_prev_items = [
torch.cat([torch.tensor(decoder_prompt), valid_answer])
for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
]
valid_constraint_mask_items = [
torch.cat([torch.zeros(len(decoder_prompt)-1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask], dim=0)
for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
]
valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad, left_pad=False).to(device)
valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad, left_pad=False).to(device)
valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad, left_pad=False).to(device)

new_encoder_out = {}
new_encoder_out["encoder_out"] = [
encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
]
new_encoder_out["encoder_padding_mask"] = [
encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
]
new_encoder_out["position_embeddings"] = [
encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
]

decoder_out = eval_model.decoder(valid_prev_output, encoder_out=new_encoder_out)
decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
lprobs = eval_model.get_normalized_probs(decoder_out, log_probs=True)
scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
scores = scores.masked_fill(valid_tgt.eq(self.tgt_dict.pad()), 0)
scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
scores = scores.sum(1)
scores = scores.view(-1, valid_size)
valid_result.append(scores)

valid_result = torch.cat(valid_result, dim=-1)
predicts = valid_result.argmax(1).tolist()
hyps = [self.index2ans[predict_index] for predict_index in predicts]

elif self.cfg.val_inference_type == "beamsearch":
raw_hyps = self.inference_step(self.generator, [eval_model], sample, prefix_tokens=sample['prefix_tokens'])
hyps = []
for i, sample_id in enumerate(sample["id"].tolist()):
prefix_len = sample['prefix_tokens'][i].ne(1).sum().item()
detok_hypo_str = decode_fn(raw_hyps[i][0]["tokens"][prefix_len:], self.tgt_dict, self.bpe, self.generator)
hyps.append(detok_hypo_str.strip())

else:
raise NotImplementedError("Error: Unknown inference type encountered.")

scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
logging_output["_vqa_score_sum"] = sum(scores)
logging_output["_vqa_cnt"] = len(scores)
Expand Down

0 comments on commit eb82a1c

Please sign in to comment.