Commit

merge origin/main
logicwong committed Mar 11, 2022
2 parents 2c1170d + bd2307a commit 6e19140
Showing 51 changed files with 250 additions and 191 deletions.
10 changes: 7 additions & 3 deletions README.md
@@ -37,6 +37,8 @@ This source code is licensed under the Apache 2.0 license found in the LICENSE f
OFA is a unified multimodal pretrained model that unifies modalities (i.e., cross-modality, vision, language) and tasks
(e.g., image generation, visual grounding, image captioning, image classification, text generation, etc.)
to a simple sequence-to-sequence learning framework. For more information, please refer to our paper: [Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework](http://arxiv.org/abs/2202.03052).

We welcome contributions to our project. Feel free to contact us or send us issues/PRs!
<br></br>


@@ -51,7 +53,7 @@ Also we provide Colab notebooks for you to better perceive the procedures. Click
<br></br>

# News
* 2022.3.08: Realease the pretrained checkpoint of OFA-Base in [checkpoints.md](checkpoints.md). To use OFA-Base, you just need to load `ofa_base.pt` and change `arch=ofa_large` to `arch=ofa_base` in the training scripts.
* 2022.3.08: Released the pretrained checkpoint of OFA-Base in [checkpoints.md](checkpoints.md). To use OFA-Base, you just need to load `ofa_base.pt` and change `--arch=ofa_large` to `--arch=ofa_base` in the training scripts.
* 2022.3.07: Released the finetuning & inference code/checkpoints for **Image Classification**, which achieves **85.0 accuracy on ImageNet-1K, slightly better than reported in OFA paper**.
* 2022.3.04: Released the finetuning & inference code/checkpoints for **Text-to-Image Generation**.
* 2022.3.03: Released the finetuning & inference code/checkpoints for **SNLI-VE** and **GLUE**.
@@ -64,9 +66,11 @@ Also we provide Colab notebooks for you to better perceive the procedures. Click

# TODO
* [x] To release finetuning and inference codes for multimodal downstream tasks soon, including image captioning, VQA, text-to-image generation, SNLI-VE, referring expression, comprehension, etc.
* [ ] To release finetuning and inference codes for unimodal downstream tasks soon.
* [x] To release finetuning and inference codes for unimodal downstream tasks soon.
* [ ] To release codes for pretraining soon.
* [ ] To integrate more downstream tasks concerning more modalities to our OFA framework.
* [ ] To release smaller models, including OFA-medium, OFA-tiny, as well as OFA-edge.
* [ ] To release OFA for Chinese.
<br></br>

# Approach
@@ -115,7 +119,7 @@ To release soon:)
<br></br>

# Finetuning & Inference
Below we provide methods for finetuning and inference on different downstream tasks. We provide both the pretrained checkpoint of OFA-Large and OFA-Base in [checkpoints.md](checkpoints.md), the following scripts is running for OFA-Large. If you want to use OFA-Base, just change the `restore_file` to the path where ofa_base.pt is located, and change `arch=ofa_large` to `arch=ofa_base`. **Note that the optimal hyperparameters for the Base model may be different from the Large model and requires proper hyperparameter tuning.
Below we provide methods for finetuning and inference on different downstream tasks. We provide both pretrained OFA-Large and OFA-Base in [checkpoints.md](checkpoints.md). The scripts in this section are prepared for OFA-Large. If you want to use OFA-Base, just modify `--restore-file` to the path where `ofa_base.pt` is located and change `--arch=ofa_large` to `--arch=ofa_base`. **Note that the optimal hyperparameters for the Base model may be different from the Large model and requires proper hyperparameter tuning.**
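The switch from OFA-Large to OFA-Base described above can be sketched in Python as a small helper that rewrites an argument list. This is only an illustration and is not part of this commit: the `override_flags` helper, the example argument list, and the checkpoint paths are all hypothetical, and the hyperparameters would still need tuning for the Base model.

```python
from typing import Dict, List


def override_flags(args: List[str], overrides: Dict[str, str]) -> List[str]:
    """Return a copy of `args` with the given options replaced.

    Handles both the `--flag=value` and the `--flag value` spellings.
    Flags that are not already present are appended at the end.
    """
    result: List[str] = []
    seen = set()
    i = 0
    while i < len(args):
        arg = args[i]
        name = arg.split("=", 1)[0]
        if name in overrides:
            result.append(f"{name}={overrides[name]}")
            seen.add(name)
            # Skip the separate value token of the `--flag value` form.
            if "=" not in arg and i + 1 < len(args) and not args[i + 1].startswith("--"):
                i += 1
        else:
            result.append(arg)
        i += 1
    for name, value in overrides.items():
        if name not in seen:
            result.append(f"{name}={value}")
    return result


# Hypothetical OFA-Large argument list; the paths are placeholders.
large_args = [
    "--arch=ofa_large",
    "--restore-file", "checkpoints/ofa_large.pt",
    "--batch-size=8",
]
base_args = override_flags(
    large_args,
    {"--arch": "ofa_base", "--restore-file": "checkpoints/ofa_base.pt"},
)
print(base_args)
```

Only `--arch` and `--restore-file` are touched; every other option is passed through unchanged, which mirrors the note that the remaining hyperparameters are left for manual tuning.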
## Image Captioning
We provide procedures to reproduce our results of image captioning on our paper below.
<details>
8 changes: 4 additions & 4 deletions criterions/clip_scst_loss.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import math
from dataclasses import dataclass, field
8 changes: 4 additions & 4 deletions criterions/label_smoothed_cross_entropy.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import math
from dataclasses import dataclass, field
8 changes: 4 additions & 4 deletions criterions/scst_loss.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import math
import string
11 changes: 6 additions & 5 deletions data/cv_data/image_classify_dataset.py
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
@@ -192,4 +193,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
10 changes: 5 additions & 5 deletions data/data_utils.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

try:
from collections.abc import Iterable
@@ -598,4 +598,4 @@ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
if ignored_paths:
advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
raise ValueError(msg)
raise ValueError(msg)
7 changes: 6 additions & 1 deletion data/file_dataset.py
@@ -1,3 +1,8 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import os
import torch
import pickle
@@ -99,4 +104,4 @@ def __getitem__(self, index):
column_l = self._reader.readline().rstrip("\n").split(self.separator)
self.data_cnt += 1
column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
return column_l
return column_l
11 changes: 6 additions & 5 deletions data/mm_data/caption_dataset.py
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
@@ -151,4 +152,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
9 changes: 5 additions & 4 deletions data/mm_data/image_gen_dataset.py
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
11 changes: 6 additions & 5 deletions data/mm_data/refcoco_dataset.py
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
@@ -165,4 +166,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
11 changes: 6 additions & 5 deletions data/mm_data/snli_ve_dataset.py
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
@@ -199,4 +200,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
9 changes: 5 additions & 4 deletions data/mm_data/vqa_gen_dataset.py
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
10 changes: 5 additions & 5 deletions data/nlu_data/cola_dataset.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import warnings
@@ -135,4 +135,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
10 changes: 5 additions & 5 deletions data/nlu_data/mnli_dataset.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import warnings
@@ -140,4 +140,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
10 changes: 5 additions & 5 deletions data/nlu_data/mrpc_dataset.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import warnings
@@ -138,4 +138,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
10 changes: 5 additions & 5 deletions data/nlu_data/qnli_dataset.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import warnings
@@ -138,4 +138,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
10 changes: 5 additions & 5 deletions data/nlu_data/qqp_dataset.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import warnings
@@ -138,4 +138,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
10 changes: 5 additions & 5 deletions data/nlu_data/rte_dataset.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import warnings
@@ -138,4 +138,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
10 changes: 5 additions & 5 deletions data/nlu_data/sst2_dataset.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import warnings
@@ -135,4 +135,4 @@ def collater(self, samples, pad_to_length=None):
Returns:
dict: a mini-batch with the following keys:
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
5 changes: 5 additions & 0 deletions data/ofa_dataset.py
@@ -1,3 +1,8 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import re
import torch.utils.data
8 changes: 4 additions & 4 deletions evaluate.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import os
9 changes: 5 additions & 4 deletions models/ofa/ofa.py
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

"""
OFA
"""
8 changes: 4 additions & 4 deletions models/ofa/unify_multihead_attention.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import math
from typing import Dict, Optional, Tuple