-
Notifications
You must be signed in to change notification settings - Fork 247
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
701 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Copyright 2022 The OFA-Sys Team. | ||
# All rights reserved. | ||
# This source code is licensed under the Apache 2.0 license | ||
# found in the LICENSE file in the root directory. | ||
|
||
import logging | ||
import warnings | ||
import torch | ||
import numpy as np | ||
|
||
from data import data_utils | ||
from data.ofa_dataset import OFADataset | ||
|
||
# Module-level logger, named after this module so log records can be traced back to it.
logger = logging.getLogger(__name__)
# Ignore UserWarnings whose message starts with "corrupt EXIF data" or
# "Possibly corrupt EXIF data" so they do not flood the logs.
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
|
||
|
||
def collate(samples, pad_idx, eos_idx):
    """Merge a list of dataset items into one mini-batch dictionary.

    Args:
        samples (List[dict]): items to merge. Each item must provide a
            ``"source"`` tensor and a ``"target_str"`` value. ``"target"`` and
            ``"prev_output_tokens"`` are optional; their presence is probed on
            the first sample only and then assumed for the whole batch.
        pad_idx: padding index; forwarded to ``data_utils.collate_tokens`` and
            used to count the non-padding tokens of every sample.
        eos_idx: end-of-sentence index, forwarded to
            ``data_utils.collate_tokens``.

    Returns:
        dict: ``{}`` when ``samples`` is empty, otherwise a dict with

        * ``"nsentences"``: number of samples,
        * ``"ntokens"``: total non-pad target tokens, or total non-pad source
          tokens when the batch has no target,
        * ``"net_input"``: ``src_tokens``, ``src_lengths`` and
          ``prev_output_tokens`` (``None`` when absent),
        * ``"target"``: collated targets or ``None``,
        * ``"target_strs"``: numpy array of the per-sample ``"target_str"``.
    """
    if not samples:
        return {}

    def merge(key):
        # Collate the per-sample tensors stored under `key` into one batch tensor.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    src_tokens = merge("source")
    # Per-sample count of non-padding source tokens.
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        )
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        # Without targets, fall back to counting source tokens.
        ntokens = src_lengths.sum().item()

    target_strs = np.array([s["target_str"] for s in samples])

    batch = {
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "prev_output_tokens": prev_output_tokens,
        },
        "target": target,
        "target_strs": target_strs,
    }

    return batch
|
||
|
||
class SummaryDataset(OFADataset):
    """Text summarization dataset producing source/target token tensors.

    Each underlying record is a ``(source, target)`` text pair. The source is
    wrapped in a fixed summarization prompt, both sides are encoded, and the
    decoder input can optionally be corrupted with random token noise.

    The helpers ``pre_caption`` and ``encode_text`` and the attributes
    ``bos_item``, ``eos_item``, ``pad``, ``eos``, ``dataset`` and ``src_dict``
    are not defined in this class; presumably they come from ``OFADataset``
    (confirm against the base class).
    """

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        code_dict_size=8192,
        num_bins=1000,
        max_src_length=512,
        max_tgt_length=128,
        noise_ratio=0.0
    ):
        """Store the length limits and noise settings.

        Args:
            split, dataset, bpe, src_dict, tgt_dict: forwarded unchanged to
                ``OFADataset.__init__``.
            code_dict_size: subtracted from the dictionary size when sampling
                noise tokens (see ``add_noise_to_tgt``).
            num_bins: subtracted from the dictionary size when sampling noise
                tokens (see ``add_noise_to_tgt``).
            max_src_length: passed as ``max_words`` to ``pre_caption`` for the
                source and as ``length`` to ``encode_text`` for the prompt.
            max_tgt_length: passed as ``max_words`` to ``pre_caption`` for the
                target.
            noise_ratio: per-position probability of replacing a decoder-input
                token with a random one.
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.code_dict_size = code_dict_size
        self.num_bins = num_bins
        self.noise_ratio = noise_ratio

    def __getitem__(self, index):
        """Build one training example from the record at ``index``.

        Returns:
            dict: ``"source"`` (bos + encoded prompt + eos), ``"target"``
            (encoded target + eos), ``"prev_output_tokens"`` (bos + possibly
            noised encoded target) and ``"target_str"`` (the lower-cased raw
            target text, taken before any other preprocessing).
        """
        source, target = self.dataset[index]
        target_str = target.lower()

        source = self.pre_caption(source, max_words=self.max_src_length)
        target = self.pre_caption(target, max_words=self.max_tgt_length)
        # Replace the literal '<unk>' marker with the plain word 'unk' on both sides.
        source = source.replace('<unk>', 'unk')
        target = target.replace('<unk>', 'unk')

        src_item = self.encode_text(
            ' what is the summary of article " {} "?'.format(source),
            length=self.max_src_length
        )
        tgt_item = self.encode_text(' {}'.format(target))
        # Noise a clone so the clean target tensor stays untouched.
        noise_tgt_item = self.add_noise_to_tgt(tgt_item.clone(), self.noise_ratio)

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, noise_tgt_item])

        example = {
            "source": src_item,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "target_str": target_str
        }
        return example

    def add_noise_to_tgt(self, target, p):
        """Randomly overwrite entries of ``target`` and return it.

        Each position is selected independently with probability ``p`` and
        replaced by an id drawn uniformly from
        ``[4, len(self.src_dict) - self.code_dict_size - self.num_bins)``.

        NOTE: ``target`` is modified in place; pass a clone to keep the
        original. The lower bound 4 presumably skips special symbols and the
        upper bound presumably excludes code/bin tokens - confirm against the
        dictionary layout.
        """
        noise_indices = torch.FloatTensor(target.size(0)).uniform_() < p
        target[noise_indices] = torch.randint(
            4, len(self.src_dict) - self.code_dict_size - self.num_bins, size=(noise_indices.sum(),)
        )
        return target

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length: accepted but not used by this implementation.

        Returns:
            dict: the result of the module-level ``collate`` called with this
            dataset's ``pad`` and ``eos`` indices.
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import datasets | ||
import sys | ||
import json | ||
|
||
# Load the ROUGE metric from the metric script at this relative path, so the
# script must be run from a directory where '../../utils/rouge.py' resolves.
rouge = datasets.load_metric('../../utils/rouge.py')


if __name__ == "__main__":
    # Expect exactly one positional argument: the JSON file of predictions.
    if len(sys.argv) < 2:
        sys.exit("usage: python3 eval_rouge.py <predict.json>")
    result_file = sys.argv[1]

    # The file holds a list of records, each with a 'hyp' (prediction) and a
    # 'ref' (reference). Use a context manager so the handle is always closed.
    with open(result_file) as fin:
        results = json.load(fin)

    predictions = [result['hyp'] for result in results]
    references = [result['ref'] for result in results]
    scores = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
    # Report the F-measure of the 'mid' aggregate for each ROUGE variant.
    print("Rouge1: ", scores["rouge1"].mid.fmeasure)
    print("Rouge2: ", scores["rouge2"].mid.fmeasure)
    print("RougeL: ", scores["rougeL"].mid.fmeasure)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#!/usr/bin/env bash

# The port for communication. Note that if you want to run multiple tasks on the same machine,
# you need to specify different port numbers.
export MASTER_PORT=2081
export CUDA_VISIBLE_DEVICES=4,5,6,7
export GPUS_PER_NODE=4

# Locations of the user module and the BPE resources, relative to this script's directory.
user_dir=../../ofa_module
bpe_dir=../../utils/BPE

# Evaluation data, checkpoint to score, and where generation results are written.
data=../../dataset/gigaword_data/gigaword_test.tsv
path=../../checkpoints/gigaword_large_best.pt
result_path=../../results/gigaword
selected_cols=0,1
split='test'

# Step 1: generate summaries with one process per GPU.
python3 -m torch.distributed.launch --nproc_per_node="${GPUS_PER_NODE}" --master_port="${MASTER_PORT}" ../../evaluate.py \
    "${data}" \
    --path="${path}" \
    --user-dir="${user_dir}" \
    --task=gigaword \
    --batch-size=32 \
    --log-format=simple --log-interval=10 \
    --seed=7 \
    --gen-subset="${split}" \
    --results-path="${result_path}" \
    --beam=6 \
    --lenpen=0.7 \
    --max-len-b=32 \
    --no-repeat-ngram-size=3 \
    --fp16 \
    --num-workers=0 \
    --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"

# Step 2: score the predictions file written to the results directory.
python3 eval_rouge.py "${result_path}/test_predict.json"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
#!/usr/bin/env bash
# Bash is required: the sweep loops below rely on brace expansion ({6,} etc.).

# The port for communication. Note that if you want to run multiple tasks on the same machine,
# you need to specify different port numbers.
export MASTER_PORT=2051
export CUDA_VISIBLE_DEVICES=0,1,2,3
export GPUS_PER_NODE=4

log_dir=./logs
save_dir=./checkpoints
mkdir -p "$log_dir" "$save_dir"

bpe_dir=../../utils/BPE
user_dir=../../ofa_module

# Training and validation files are passed as one comma-separated argument.
data_dir=../../dataset/gigaword_data
data=${data_dir}/gigaword_train.tsv,${data_dir}/gigaword_dev.tsv
restore_file=../../checkpoints/ofa_large.pt
selected_cols=0,1

task=gigaword
arch=ofa_large
criterion=adjust_label_smoothed_cross_entropy
label_smoothing=0.1
# NOTE: lr and max_epoch are reassigned by the sweep loops below, so the two
# values set here are never the ones actually passed to train.py.
lr=5e-5
max_epoch=6
warmup_ratio=0.06
batch_size=64
update_freq=2
resnet_drop_path_rate=0.0
encoder_drop_path_rate=0.1
decoder_drop_path_rate=0.1
dropout=0.1
attention_dropout=0.0
max_src_length=512
max_tgt_length=64
num_bins=1000

# Hyper-parameter sweep: add comma-separated values inside the braces to try more settings.
for max_epoch in {6,}; do
  echo "max_epoch ${max_epoch}"
  for lr in {1e-4,}; do
    echo "lr ${lr}"
    for noise_ratio in {0.2,}; do
      echo "noise_ratio ${noise_ratio}"

      # One log file and one checkpoint directory per hyper-parameter combination.
      log_file="${log_dir}/${max_epoch}_${lr}_${noise_ratio}.log"
      save_path="${save_dir}/${max_epoch}_${lr}_${noise_ratio}"
      mkdir -p "$save_path"

      python3 -m torch.distributed.launch --nproc_per_node="${GPUS_PER_NODE}" --master_port="${MASTER_PORT}" ../../train.py \
          "$data" \
          --selected-cols="${selected_cols}" \
          --bpe-dir="${bpe_dir}" \
          --user-dir="${user_dir}" \
          --restore-file="${restore_file}" \
          --reset-optimizer --reset-dataloader --reset-meters \
          --save-dir="${save_path}" \
          --task="${task}" \
          --arch="${arch}" \
          --criterion="${criterion}" \
          --label-smoothing="${label_smoothing}" \
          --batch-size="${batch_size}" \
          --update-freq="${update_freq}" \
          --encoder-normalize-before \
          --decoder-normalize-before \
          --share-decoder-input-output-embed \
          --share-all-embeddings \
          --layernorm-embedding \
          --patch-layernorm-embedding \
          --code-layernorm-embedding \
          --resnet-drop-path-rate="${resnet_drop_path_rate}" \
          --encoder-drop-path-rate="${encoder_drop_path_rate}" \
          --decoder-drop-path-rate="${decoder_drop_path_rate}" \
          --dropout="${dropout}" \
          --attention-dropout="${attention_dropout}" \
          --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \
          --lr-scheduler=polynomial_decay --lr="${lr}" \
          --max-epoch="${max_epoch}" --warmup-ratio="${warmup_ratio}" \
          --log-format=simple --log-interval=10 \
          --fixed-validation-seed=7 \
          --no-epoch-checkpoints --keep-best-checkpoints=1 \
          --save-interval=1 --validate-interval=1 \
          --save-interval-updates=2500 --validate-interval-updates=2500 \
          --best-checkpoint-metric=rougeL_f1 --maximize-best-checkpoint-metric \
          --max-src-length="${max_src_length}" \
          --max-tgt-length="${max_tgt_length}" \
          --find-unused-parameters \
          --eval-rouge \
          --eval-print-samples \
          --eval-args='{"beam":6,"lenpen":0.7,"max_len_b":32,"no_repeat_ngram_size":3}' \
          --add-type-embedding \
          --scale-attn \
          --scale-fc \
          --scale-heads \
          --disable-entangle \
          --num-bins="${num_bins}" \
          --noise-ratio="${noise_ratio}" \
          --fp16 \
          --fp16-scale-window=512 \
          --num-workers=0 > "${log_file}" 2>&1
    done
  done
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .mm_tasks import * | ||
from .cv_tasks import * | ||
from .mm_tasks import * | ||
from .nlg_tasks import * | ||
from .nlu_tasks import * | ||
from .ofa_task import OFATask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .gigaword import GigawordTask |
Oops, something went wrong.