Merge pull request #185 from OFA-Sys/features/ofa_cn
add cn ofa
JustinLin610 authored Jul 29, 2022
2 parents 2cc60da + 292ea26 commit b8c93fd
Showing 9 changed files with 42,354 additions and 16 deletions.
7 changes: 6 additions & 1 deletion data/mm_data/caption_dataset.py
@@ -113,6 +113,11 @@ def __init__(
             transforms.Normalize(mean=mean, std=std),
         ])
 
+        if type(bpe).__name__ == 'GPT2BPE':
+            self.prompt = " what does the image describe?"
+        elif type(bpe).__name__ == 'BertBPE':
+            self.prompt = "图片描述了什么内容?"
+
     def __getitem__(self, index):
         uniq_id, image, caption = self.dataset[index]
 
@@ -128,7 +133,7 @@ def __getitem__(self, index):
         caption = ' '.join(caption.strip().split())
         caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
         tgt_caption = '&&'.join(caption_list)
-        src_item = self.encode_text(" what does the image describe?")
+        src_item = self.encode_text(self.prompt)
         tgt_item = self.encode_text(" {}".format(tgt_caption))
 
         src_item = torch.cat([self.bos_item, src_item, self.eos_item])
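For reference, a minimal standalone sketch (not part of the commit) of the prompt-dispatch pattern this change adds to each dataset: the class name of the configured BPE selects the English prompt for the GPT-2 tokenizer or the Chinese prompt for the BERT tokenizer. The Chinese caption prompt translates to "What does the image describe?".

    def select_caption_prompt(bpe) -> str:
        # GPT2BPE -> English OFA; BertBPE -> Chinese OFA (OFA-CN)
        if type(bpe).__name__ == 'GPT2BPE':
            return " what does the image describe?"
        elif type(bpe).__name__ == 'BertBPE':
            return "图片描述了什么内容?"  # "What does the image describe?"
        raise ValueError("unsupported BPE: {}".format(type(bpe).__name__))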
7 changes: 6 additions & 1 deletion data/mm_data/refcoco_dataset.py
@@ -118,6 +118,11 @@ def __init__(
             T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
         ])
 
+        if type(bpe).__name__ == 'GPT2BPE':
+            self.prompt = ' which region does the text " {} " describe?'
+        elif type(bpe).__name__ == 'BertBPE':
+            self.prompt = '这段文字" {} "描述的是哪个区域?'
+
     def __getitem__(self, index):
         uniq_id, base64_str, text, region_coord = self.dataset[index]
 
@@ -139,7 +144,7 @@ def __getitem__(self, index):
         quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
         region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
         src_caption = self.pre_caption(text, self.max_src_length)
-        src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
+        src_item = self.encode_text(self.prompt.format(src_caption))
         tgt_item = self.encode_text(region_coord, use_bpe=False)
 
         src_item = torch.cat([self.bos_item, src_item, self.eos_item])
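The Chinese prompt translates to "Which region does the text " {} " describe?". For context on the <bin_k> target tokens built in the second hunk, here is an illustrative quantization example (a sketch under assumptions: box coordinates normalized to [0, 1] and the default num_bins of 1000; not code from the repo):

    num_bins = 1000
    x0 = 0.25  # one normalized box coordinate
    token = "<bin_{}>".format(int(round(x0 * (num_bins - 1))))
    print(token)  # <bin_250>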
9 changes: 7 additions & 2 deletions data/nlg_data/summary_dataset.py
@@ -81,6 +81,11 @@ def __init__(
         self.num_bins = num_bins
         self.noise_ratio = noise_ratio
 
+        if type(bpe).__name__ == 'GPT2BPE':
+            self.prompt = ' what is the summary of article " {} "?'
+        elif type(bpe).__name__ == 'BertBPE':
+            self.prompt = "{} 请用一个句子简单总结上文:"
+
     def __getitem__(self, index):
         source, target = self.dataset[index]
         target_str = target.lower()
 
@@ -91,10 +96,10 @@
         target = target.replace('<unk>', 'unk')
 
         src_item = self.encode_text(
-            ' what is the summary of article " {} "?'.format(source),
+            self.prompt.format(source),
             length=self.max_src_length
         )
-        tgt_item = self.encode_text(' {}'.format(target))
+        tgt_item = self.encode_text('{}'.format(target))
         noise_tgt_item = self.add_noise_to_tgt(tgt_item.clone(), self.noise_ratio)
 
         src_item = torch.cat([self.bos_item, src_item, self.eos_item])
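The Chinese summarization prompt places the instruction after the article and translates to "{article} Please summarize the above in one sentence:". One subtlety in the second hunk: the target is now encoded as '{}'.format(target) rather than ' {}'.format(target). For GPT-2-style BPE the leading space is significant, since tokens absorb their preceding whitespace; an illustration using the Hugging Face tokenizer as a stand-in (an assumption for demonstration, not a dependency of this repo):

    from transformers import GPT2Tokenizer  # illustration only, assumed installed

    tok = GPT2Tokenizer.from_pretrained("gpt2")
    print(tok.tokenize("hello"))   # ['hello']
    print(tok.tokenize(" hello"))  # ['Ġhello'] -- a different token when the space is kept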
8 changes: 4 additions & 4 deletions data/ofa_dataset.py
@@ -42,7 +42,7 @@ def encode_text(self, text, length=None, append_bos=False, append_eos=False, use
             s = torch.cat([s, self.eos_item])
         return s
 
-    def pre_question(self, question, max_ques_words):
+    def pre_question(self, question, max_ques_words=None):
         question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
 
         question = re.sub(
@@ -55,12 +55,12 @@ def pre_question(self, question, max_ques_words):
 
         # truncate question
         question_words = question.split(' ')
-        if len(question_words) > max_ques_words:
+        if max_ques_words is not None and len(question_words) > max_ques_words:
             question = ' '.join(question_words[:max_ques_words])
 
         return question
 
-    def pre_caption(self, caption, max_words):
+    def pre_caption(self, caption, max_words=None):
         caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
 
         caption = re.sub(
@@ -73,7 +73,7 @@ def pre_caption(self, caption, max_words):
 
         # truncate caption
         caption_words = caption.split(' ')
-        if len(caption_words) > max_words:
+        if max_words is not None and len(caption_words) > max_words:
             caption = ' '.join(caption_words[:max_words])
 
         return caption
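A brief usage sketch of the relaxed signatures (ds stands for any OFADataset subclass instance; a hypothetical name): with the new None default, omitting the limit means "clean the text but never truncate", which was not possible before this change (the parameter was required, and an explicit None would fail at the len(...) > None comparison).

    clean = ds.pre_caption("Two dogs - playing / in the snow")               # cleaned, not truncated
    short = ds.pre_caption("Two dogs - playing / in the snow", max_words=2)  # keeps the first 2 words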
19 changes: 18 additions & 1 deletion models/ofa/unify_transformer.py
@@ -225,6 +225,8 @@ def add_args(parser):
                             help='freeze decoder token embedding')
         parser.add_argument('--add-type-embedding', action='store_true',
                             help='add source/region/patch type embedding')
+        parser.add_argument('--interpolate-position', action='store_true',
+                            help='interpolate position')
 
         parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'],
                             help='resnet type')
@@ -498,6 +500,9 @@ def __init__(self, args, dictionary, embed_tokens):
             [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
         )
 
+        self.patch_image_size = args.patch_image_size
+        self.orig_patch_image_size = args.orig_patch_image_size
+
         self.register_buffer("token_rp_bucket", token_rp_bucket)
         self.register_buffer("image_rp_bucket", image_rp_bucket)
         self.entangle_position_embedding = args.entangle_position_embedding
@@ -560,7 +565,19 @@ def get_patch_images_info(self, patch_images, sample_patch_num, device):
             image_num_patches = sample_patch_num
             image_padding_mask = image_padding_mask.gather(1, patch_orders)
             image_position_ids = image_position_ids.gather(1, patch_orders)
-        image_pos_embed = self.embed_image_positions(image_position_ids)
+        orig_num_patches = (self.orig_patch_image_size // 16) ** 2
+        orig_hw = self.orig_patch_image_size // 16
+        if getattr(self.args, "interpolate_position", False) and image_num_patches > orig_num_patches:
+            old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \
+                                     torch.arange(orig_hw).unsqueeze(1) * self.args.image_bucket_size + 1
+            old_image_position_ids = old_image_position_ids.to(device)
+            old_image_pos_embed = self.embed_image_positions(old_image_position_ids)
+            old_image_pos_embed = old_image_pos_embed.reshape(1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)
+            image_pos_embed = F.interpolate(old_image_pos_embed, size=(h, w), mode='bilinear')
+            image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(1, image_num_patches, -1)
+            image_pos_embed = image_pos_embed.expand(patch_images.size(0), -1, -1)
+        else:
+            image_pos_embed = self.embed_image_positions(image_position_ids)
 
         return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed

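A self-contained sketch (simplified from the hunk above; the helper name is hypothetical) of what --interpolate-position does: position embeddings learned for the original orig_hw x orig_hw patch grid are treated as a 2-D map and bilinearly resized to the larger h x w grid, so a model trained at orig_patch_image_size (256, a 16x16 grid of stride-16 patches) can be evaluated at a larger patch_image_size (e.g. 480, a 30x30 grid):

    import torch
    import torch.nn.functional as F

    def interpolate_pos_embed(pos_embed: torch.Tensor, orig_hw: int, h: int, w: int) -> torch.Tensor:
        # (orig_hw * orig_hw, dim) learned grid -> (h * w, dim) resized grid
        dim = pos_embed.size(-1)
        grid = pos_embed.reshape(1, orig_hw, orig_hw, dim).permute(0, 3, 1, 2)  # to NCHW
        grid = F.interpolate(grid, size=(h, w), mode='bilinear')  # bilinear resize, as in the diff
        return grid.permute(0, 2, 3, 1).reshape(h * w, dim)

    # e.g. embeddings trained at 256px (16x16 patches) reused at 480px (30x30 patches):
    # new_pe = interpolate_pos_embed(old_pe, orig_hw=16, h=30, w=30)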
34 changes: 34 additions & 0 deletions run_scripts/refcoco/evaluate_refcoco_large_cn.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=6081
+export CUDA_VISIBLE_DEVICES=7
+export GPUS_PER_NODE=1
+
+user_dir=../../ofa_module
+bpe_dir=../../utils/BERT_CN_dict
+selected_cols=0,3,1,2
+
+data=../../dataset/refcoco_cn_data/refcoco+_test_sample.tsv
+path=../../checkpoints/refcocoplus_cn_large.pt
+result_path=../../results/refcoco
+split='refcoco_val'
+python3 ../../evaluate.py \
+    ${data} \
+    --path=${path} \
+    --user-dir=${user_dir} \
+    --task=refcoco \
+    --batch-size=16 \
+    --log-format=simple --log-interval=10 \
+    --seed=7 \
+    --gen-subset=${split} \
+    --results-path=${result_path} \
+    --beam=5 \
+    --min-len=4 \
+    --max-len-a=0 \
+    --max-len-b=4 \
+    --no-repeat-ngram-size=3 \
+    --fp16 \
+    --num-workers=0 \
+    --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
30 changes: 23 additions & 7 deletions tasks/ofa_task.py
@@ -34,6 +34,10 @@ class OFAConfig(FairseqDataclass):
         default=None,
         metadata={"help": "selected cols"},
     )
+    bpe: Optional[str] = field(
+        default='gpt2',
+        metadata={"help": "which bpe to use"},
+    )
     bpe_dir: Optional[str] = field(
         default=None,
         metadata={"help": "bpe dir"},
@@ -57,6 +61,9 @@ class OFAConfig(FairseqDataclass):
     patch_image_size: int = field(
         default=480, metadata={"help": "patch image size"}
     )
+    orig_patch_image_size: int = field(
+        default=256, metadata={"help": "patch image size"}
+    )
     num_bins: int = field(
         default=1000, metadata={"help": "number of quantization bins"}
     )
@@ -151,13 +158,22 @@ def get_batch_iterator(
 
     def build_model(self, cfg: FairseqDataclass):
         model = super().build_model(cfg)
-        bpe_dict = {
-            "_name": "gpt2",
-            "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
-            "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe")
-        }
-        bpe_dict = DictConfig(bpe_dict)
-        self.bpe = self.build_bpe(bpe_dict)
+        if self.cfg.bpe == 'bert':
+            bpe_dict = {
+                "_name": "bert",
+                "bpe_vocab_file": os.path.join(self.cfg.bpe_dir, "vocab.txt"),
+                "bpe_cased": False
+            }
+            bpe_dict = DictConfig(bpe_dict)
+            self.bpe = self.build_bpe(bpe_dict)
+        else:
+            bpe_dict = {
+                "_name": "gpt2",
+                "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
+                "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe")
+            }
+            bpe_dict = DictConfig(bpe_dict)
+            self.bpe = self.build_bpe(bpe_dict)
         return model
 
     def build_generator(
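As a compact restatement (a hypothetical helper, not code from the commit) of the tokenizer selection added to build_model: the two branches differ only in the fairseq BPE name and the vocabulary files looked up under bpe_dir.

    import os
    from omegaconf import DictConfig

    def make_bpe_config(bpe: str, bpe_dir: str) -> DictConfig:
        # 'bert' -> Chinese vocab.txt; anything else -> GPT-2 encoder.json / vocab.bpe
        if bpe == 'bert':
            return DictConfig({
                "_name": "bert",
                "bpe_vocab_file": os.path.join(bpe_dir, "vocab.txt"),
                "bpe_cased": False,
            })
        return DictConfig({
            "_name": "gpt2",
            "gpt2_encoder_json": os.path.join(bpe_dir, "encoder.json"),
            "gpt2_vocab_bpe": os.path.join(bpe_dir, "vocab.bpe"),
        })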
(The remaining 2 of the 9 changed files are not shown.)