Skip to content

Commit

Permalink
fixed import bug & install fairseq without editable mode & add api
Browse files Browse the repository at this point in the history
  • Loading branch information
logicwong committed May 25, 2023
1 parent 3ecbffe commit e43583e
Show file tree
Hide file tree
Showing 49 changed files with 477 additions and 104 deletions.
2 changes: 1 addition & 1 deletion fairseq/fairseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class CommonConfig(FairseqDataclass):
},
)

# for one-piece
# for one-peace
layer_decay: float = field(
default=1.0,
metadata={
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np
from utils.data_utils import collate_tokens
from ..utils.data_utils import collate_tokens


def collate_fn(samples, pad_idx, pad_to_length=None):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/audio_data/aqa_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from data.base_dataset import BaseDataset
from ..base_dataset import BaseDataset


class AQADataset(BaseDataset):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/audio_data/audio_classify_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from data.base_dataset import BaseDataset
from ..base_dataset import BaseDataset


class AudioClassifyDataset(BaseDataset):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/audio_data/audio_text_retrieval_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from data.base_dataset import BaseDataset
from ..base_dataset import BaseDataset


class AudioTextRetrievalDataset(BaseDataset):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/audio_data/vggsound_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from data.base_dataset import BaseDataset
from ..base_dataset import BaseDataset


class VggsoundDataset(BaseDataset):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from fairseq.data import FairseqDataset

from data import collate_fn
from . import collate_fn

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fairseq.data.iterators import CountingIterator, BufferedIterator
from fairseq.data import data_utils

from utils.data_utils import new_islice
from ..utils.data_utils import new_islice

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions one_peace/data/pretrain_data/audio_text_pretrain_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import torch

from data.base_dataset import BaseDataset
from utils.data_utils import get_whole_word_mask, compute_block_mask_1d
from ..base_dataset import BaseDataset
from ...utils.data_utils import get_whole_word_mask, compute_block_mask_1d


class AudioTextPretrainDataset(BaseDataset):
Expand Down
4 changes: 2 additions & 2 deletions one_peace/data/pretrain_data/image_text_pretrain_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from utils.data_utils import get_whole_word_mask
from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from ...utils.data_utils import get_whole_word_mask


class ImageTextPretrainDataset(BaseDataset):
Expand Down
8 changes: 4 additions & 4 deletions one_peace/data/vision_data/image_classify_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from timm.data import create_transform
from timm.data.mixup import Mixup

from data import collate_fn
from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from utils.randaugment import RandomAugment
import utils.transforms as utils_transforms
from .. import collate_fn
from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from ...utils.randaugment import RandomAugment
from ...utils import transforms as utils_transforms


class ImageClassifyDataset(BaseDataset):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/vl_data/image_text_retrieval_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD


class ImageTextRetrievalDataset(BaseDataset):
Expand Down
6 changes: 3 additions & 3 deletions one_peace/data/vl_data/nlvr2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from utils.randaugment import RandomAugment
import utils.transforms as utils_transforms
from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from ...utils.randaugment import RandomAugment
from ...utils import transforms as utils_transforms


class Nlvr2Dataset(BaseDataset):
Expand Down
4 changes: 2 additions & 2 deletions one_peace/data/vl_data/refcoco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
import torch

from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
import utils.transforms as T
from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from ...utils import transforms as T


class RefCOCODataset(BaseDataset):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/data/vl_data/vqa_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD


class VqaDataset(BaseDataset):
Expand Down
1 change: 1 addition & 0 deletions one_peace/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fairseq.dataclass.initialize import add_defaults
from omegaconf import DictConfig

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
Expand Down
2 changes: 1 addition & 1 deletion one_peace/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.distributed as dist

from .base_metric import BaseMetric
from utils.data_utils import all_gather
from ..utils.data_utils import all_gather


class Accuracy(BaseMetric):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/metrics/iou_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.distributed as dist

from .base_metric import BaseMetric
from utils.data_utils import all_gather
from ..utils.data_utils import all_gather


class IouAcc(BaseMetric):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/metrics/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sklearn.metrics import average_precision_score

from .base_metric import BaseMetric
from utils.data_utils import all_gather
from ..utils.data_utils import all_gather


class MAP(BaseMetric):
Expand Down
2 changes: 1 addition & 1 deletion one_peace/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.distributed as dist

from .base_metric import BaseMetric
from utils.data_utils import all_gather
from ..utils.data_utils import all_gather


class Recall(BaseMetric):
Expand Down
3 changes: 2 additions & 1 deletion one_peace/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .one_peace.one_peace_base import OnePeaceBaseModel
from .one_peace.one_peace_classify import OnePeaceClassifyModel
from .one_peace.one_peace_pretrain import OnePeacePretrainModel
from .one_peace.one_peace_retrieval import OnePeaceRetrievalModel
from .one_peace.one_peace_retrieval import OnePeaceRetrievalModel
from .one_peace.hub_interface import from_pretrained
2 changes: 1 addition & 1 deletion one_peace/models/adapter/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fairseq.modules import FairseqDropout
from fairseq import utils

from models.components import Embedding, trunc_normal_, LayerNorm, Linear
from ..components import Embedding, trunc_normal_, LayerNorm, Linear

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion one_peace/models/adapter/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch.nn.functional as F
from fairseq.modules import FairseqDropout

from models.components import Embedding, trunc_normal_, LayerNorm
from ..components import Embedding, trunc_normal_, LayerNorm

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion one_peace/models/adapter/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fairseq.modules import FairseqDropout
from fairseq import utils

from models.components import Embedding, trunc_normal_, LayerNorm
from ..components import Embedding, trunc_normal_, LayerNorm

logger = logging.getLogger(__name__)

Expand Down
Loading

0 comments on commit e43583e

Please sign in to comment.