From e43583e16d6e4440ef371b97e6112428675cd645 Mon Sep 17 00:00:00 2001
From: logicwong <798960736@qq.com>
Date: Thu, 25 May 2023 21:41:52 +0800
Subject: [PATCH] Fix import bugs, install fairseq without editable mode, and
 add a hub interface API
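
Intra-package imports are converted from absolute form (from utils..., from
data..., from models..., from tasks...) to package-relative form so that
one_peace also resolves as an installed package, the bundled fairseq is
installed as a regular package instead of an editable one, and a
from_pretrained hub interface is added for one-line feature extraction.

A minimal usage sketch of the new API (the media file names below are
placeholders, and the "ONE-PEACE" checkpoint is fetched from the _MODELS URL
on first use):

    import torch
    from one_peace.models import from_pretrained

    model = from_pretrained("ONE-PEACE", device="cuda")

    # one preprocessing helper per modality
    src_tokens = model.process_text(["a dog barking"])
    src_images = model.process_image(["dog.jpg"])
    src_audios, audio_padding_masks = model.process_audio(["dog.wav"])

    # unimodal feature extraction
    with torch.no_grad():
        text_features = model.extract_text_features(src_tokens)
        image_features = model.extract_image_features(src_images)
        audio_features = model.extract_audio_features(src_audios, audio_padding_masks)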
---
 fairseq/fairseq/dataclass/configs.py               |   2 +-
 one_peace/data/__init__.py                         |   2 +-
 one_peace/data/audio_data/aqa_dataset.py           |   2 +-
 .../data/audio_data/audio_classify_dataset.py      |   2 +-
 .../audio_text_retrieval_dataset.py                |   2 +-
 one_peace/data/audio_data/vggsound_dataset.py      |   2 +-
 one_peace/data/base_dataset.py                     |   2 +-
 one_peace/data/iterators.py                        |   2 +-
 .../audio_text_pretrain_dataset.py                 |   4 +-
 .../image_text_pretrain_dataset.py                 |   4 +-
 .../vision_data/image_classify_dataset.py          |   8 +-
 .../vl_data/image_text_retrieval_dataset.py        |   2 +-
 one_peace/data/vl_data/nlvr2_dataset.py            |   6 +-
 one_peace/data/vl_data/refcoco_dataset.py          |   4 +-
 one_peace/data/vl_data/vqa_dataset.py              |   2 +-
 one_peace/evaluate.py                              |   1 +
 one_peace/metrics/accuracy.py                      |   2 +-
 one_peace/metrics/iou_acc.py                       |   2 +-
 one_peace/metrics/map.py                           |   2 +-
 one_peace/metrics/recall.py                        |   2 +-
 one_peace/models/__init__.py                       |   3 +-
 one_peace/models/adapter/audio.py                  |   2 +-
 one_peace/models/adapter/image.py                  |   2 +-
 one_peace/models/adapter/text.py                   |   2 +-
 one_peace/models/one_peace/hub_interface.py        | 184 ++++++++++++++++++
 one_peace/models/one_peace/one_peace_base.py       |  16 +-
 .../models/one_peace/one_peace_classify.py         |   6 +-
 .../models/one_peace/one_peace_pretrain.py         |   8 +-
 .../models/one_peace/one_peace_retrieval.py        |   8 +-
 .../models/transformer/multihead_attention.py      |   2 +-
 .../models/transformer/transformer_encoder.py      |   2 +-
 .../models/transformer/transformer_layer.py        |   2 +-
 one_peace/tasks/audio_tasks/aqa.py                 |   6 +-
 one_peace/tasks/audio_tasks/audio_classify.py      |   6 +-
 .../tasks/audio_tasks/audio_text_retrieval.py      |   8 +-
 one_peace/tasks/audio_tasks/vggsound.py            |   6 +-
 one_peace/tasks/base_task.py                       |   4 +-
 .../pretrain_tasks/audio_text_pretrain.py          |   6 +-
 .../pretrain_tasks/image_text_pretrain.py          |   8 +-
 .../tasks/vision_tasks/image_classify.py           |   6 +-
 .../tasks/vl_tasks/image_text_retrieval.py         |   8 +-
 one_peace/tasks/vl_tasks/nlvr2.py                  |   6 +-
 one_peace/tasks/vl_tasks/refcoco.py                |   6 +-
 one_peace/tasks/vl_tasks/vqa.py                    |   6 +-
 one_peace/train.py                                 |   4 +-
 one_peace/trainer.py                               |   6 +-
 one_peace/user_module/__init__.py                  |  14 +-
 one_peace/utils/hub_interface.py                   | 184 ++++++++++++++++++
 requirements.txt                                   |   5 +-
 49 files changed, 477 insertions(+), 104 deletions(-)
 create mode 100644 one_peace/models/one_peace/hub_interface.py
 create mode 100644 one_peace/utils/hub_interface.py

diff --git a/fairseq/fairseq/dataclass/configs.py b/fairseq/fairseq/dataclass/configs.py
index 090a04a..86fe7c7 100644
--- a/fairseq/fairseq/dataclass/configs.py
+++ b/fairseq/fairseq/dataclass/configs.py
@@ -251,7 +251,7 @@ class CommonConfig(FairseqDataclass):
         },
     )
 
-    # for one-piece
+    # for one-peace
     layer_decay: float = field(
         default=1.0,
         metadata={
diff --git a/one_peace/data/__init__.py b/one_peace/data/__init__.py
index 96a072e..33e3c17 100644
--- a/one_peace/data/__init__.py
+++ b/one_peace/data/__init__.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from utils.data_utils import collate_tokens
+from ..utils.data_utils import collate_tokens
 
 
 def collate_fn(samples, pad_idx, pad_to_length=None):
diff --git a/one_peace/data/audio_data/aqa_dataset.py b/one_peace/data/audio_data/aqa_dataset.py
index bcc3486..48ce992 100644
--- a/one_peace/data/audio_data/aqa_dataset.py
+++ b/one_peace/data/audio_data/aqa_dataset.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from data.base_dataset import BaseDataset
+from ..base_dataset import BaseDataset
 
 
 class AQADataset(BaseDataset):
diff --git a/one_peace/data/audio_data/audio_classify_dataset.py b/one_peace/data/audio_data/audio_classify_dataset.py
index a41fe37..ccdea66 100644
--- a/one_peace/data/audio_data/audio_classify_dataset.py
+++ b/one_peace/data/audio_data/audio_classify_dataset.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from data.base_dataset import BaseDataset
+from ..base_dataset import BaseDataset
 
 
 class AudioClassifyDataset(BaseDataset):
diff --git a/one_peace/data/audio_data/audio_text_retrieval_dataset.py b/one_peace/data/audio_data/audio_text_retrieval_dataset.py
index d2381b5..1df731f 100644
--- a/one_peace/data/audio_data/audio_text_retrieval_dataset.py
+++ b/one_peace/data/audio_data/audio_text_retrieval_dataset.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from data.base_dataset import BaseDataset
+from ..base_dataset import BaseDataset
 
 
 class AudioTextRetrievalDataset(BaseDataset):
diff --git a/one_peace/data/audio_data/vggsound_dataset.py b/one_peace/data/audio_data/vggsound_dataset.py
index 3e6ca84..aef24ab 100644
--- a/one_peace/data/audio_data/vggsound_dataset.py
+++ b/one_peace/data/audio_data/vggsound_dataset.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from data.base_dataset import BaseDataset
+from ..base_dataset import BaseDataset
 
 
 class VggsoundDataset(BaseDataset):
diff --git a/one_peace/data/base_dataset.py b/one_peace/data/base_dataset.py
index 842426f..f936144 100644
--- a/one_peace/data/base_dataset.py
+++ b/one_peace/data/base_dataset.py
@@ -15,7 +15,7 @@
 
 from fairseq.data import FairseqDataset
 
-from data import collate_fn
+from . import collate_fn
 
 logger = logging.getLogger(__name__)
 warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
diff --git a/one_peace/data/iterators.py b/one_peace/data/iterators.py
index e910fad..a153dd1 100644
--- a/one_peace/data/iterators.py
+++ b/one_peace/data/iterators.py
@@ -9,7 +9,7 @@
 
 from fairseq.data.iterators import CountingIterator, BufferedIterator
 from fairseq.data import data_utils
 
-from utils.data_utils import new_islice
+from ..utils.data_utils import new_islice
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/data/pretrain_data/audio_text_pretrain_dataset.py b/one_peace/data/pretrain_data/audio_text_pretrain_dataset.py
index e171579..4320557 100644
--- a/one_peace/data/pretrain_data/audio_text_pretrain_dataset.py
+++ b/one_peace/data/pretrain_data/audio_text_pretrain_dataset.py
@@ -7,8 +7,8 @@
 
 import torch
 
-from data.base_dataset import BaseDataset
-from utils.data_utils import get_whole_word_mask, compute_block_mask_1d
+from ..base_dataset import BaseDataset
+from ...utils.data_utils import get_whole_word_mask, compute_block_mask_1d
 
 
 class AudioTextPretrainDataset(BaseDataset):
diff --git a/one_peace/data/pretrain_data/image_text_pretrain_dataset.py b/one_peace/data/pretrain_data/image_text_pretrain_dataset.py
index 18d2267..1b9a54f 100644
--- a/one_peace/data/pretrain_data/image_text_pretrain_dataset.py
+++ b/one_peace/data/pretrain_data/image_text_pretrain_dataset.py
@@ -9,8 +9,8 @@
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 
-from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
-from utils.data_utils import get_whole_word_mask
+from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ...utils.data_utils import get_whole_word_mask
 
 
 class ImageTextPretrainDataset(BaseDataset):
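
Every dataset, metric, model, and task module below gets the same treatment as
the files above: top-level imports such as "from utils.data_utils import ..."
are rewritten relative to the one_peace package. A small sketch of the failure
mode the old form causes once one_peace is installed site-wide instead of being
run from its source directory (module names taken from this repo):

    import importlib

    importlib.import_module("one_peace.data")  # fine: relative imports resolve in-package
    try:
        importlib.import_module("utils.data_utils")  # the old absolute form
    except ModuleNotFoundError as err:
        print(err)  # "No module named 'utils'", hence the switch to relative imports
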
diff --git a/one_peace/data/vision_data/image_classify_dataset.py b/one_peace/data/vision_data/image_classify_dataset.py
index 697f208..496315b 100644
--- a/one_peace/data/vision_data/image_classify_dataset.py
+++ b/one_peace/data/vision_data/image_classify_dataset.py
@@ -9,10 +9,10 @@
 from timm.data import create_transform
 from timm.data.mixup import Mixup
 
-from data import collate_fn
-from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
-from utils.randaugment import RandomAugment
-import utils.transforms as utils_transforms
+from .. import collate_fn
+from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ...utils.randaugment import RandomAugment
+from ...utils import transforms as utils_transforms
 
 
 class ImageClassifyDataset(BaseDataset):
diff --git a/one_peace/data/vl_data/image_text_retrieval_dataset.py b/one_peace/data/vl_data/image_text_retrieval_dataset.py
index f7b05d2..a0079f2 100644
--- a/one_peace/data/vl_data/image_text_retrieval_dataset.py
+++ b/one_peace/data/vl_data/image_text_retrieval_dataset.py
@@ -7,7 +7,7 @@
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 
-from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
 
 
 class ImageTextRetrievalDataset(BaseDataset):
diff --git a/one_peace/data/vl_data/nlvr2_dataset.py b/one_peace/data/vl_data/nlvr2_dataset.py
index d119f8d..49e2697 100644
--- a/one_peace/data/vl_data/nlvr2_dataset.py
+++ b/one_peace/data/vl_data/nlvr2_dataset.py
@@ -7,9 +7,9 @@
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 
-from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
-from utils.randaugment import RandomAugment
-import utils.transforms as utils_transforms
+from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ...utils.randaugment import RandomAugment
+from ...utils import transforms as utils_transforms
 
 
 class Nlvr2Dataset(BaseDataset):
diff --git a/one_peace/data/vl_data/refcoco_dataset.py b/one_peace/data/vl_data/refcoco_dataset.py
index 4b47501..0d7acdf 100644
--- a/one_peace/data/vl_data/refcoco_dataset.py
+++ b/one_peace/data/vl_data/refcoco_dataset.py
@@ -6,8 +6,8 @@
 import numpy as np
 import torch
 
-from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
-import utils.transforms as T
+from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ...utils import transforms as T
 
 
 class RefCOCODataset(BaseDataset):
diff --git a/one_peace/data/vl_data/vqa_dataset.py b/one_peace/data/vl_data/vqa_dataset.py
index 347d524..751c174 100644
--- a/one_peace/data/vl_data/vqa_dataset.py
+++ b/one_peace/data/vl_data/vqa_dataset.py
@@ -7,7 +7,7 @@
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 
-from data.base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ..base_dataset import BaseDataset, CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
 
 
 class VqaDataset(BaseDataset):
diff --git a/one_peace/evaluate.py b/one_peace/evaluate.py
index 276f13d..4000f92 100644
--- a/one_peace/evaluate.py
+++ b/one_peace/evaluate.py
@@ -21,6 +21,7 @@
 from fairseq.dataclass.initialize import add_defaults
 from omegaconf import DictConfig
 
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 
 logging.basicConfig(
     format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
diff --git a/one_peace/metrics/accuracy.py b/one_peace/metrics/accuracy.py
index 205cee7..b2075a6 100644
--- a/one_peace/metrics/accuracy.py
+++ b/one_peace/metrics/accuracy.py
@@ -2,7 +2,7 @@
 import torch.distributed as dist
 
 from .base_metric import BaseMetric
-from utils.data_utils import all_gather
+from ..utils.data_utils import all_gather
 
 
 class Accuracy(BaseMetric):
diff --git a/one_peace/metrics/iou_acc.py b/one_peace/metrics/iou_acc.py
index 098e1db..f953af8 100644
--- a/one_peace/metrics/iou_acc.py
+++ b/one_peace/metrics/iou_acc.py
@@ -2,7 +2,7 @@
 import torch.distributed as dist
 
 from .base_metric import BaseMetric
-from utils.data_utils import all_gather
+from ..utils.data_utils import all_gather
 
 
 class IouAcc(BaseMetric):
diff --git a/one_peace/metrics/map.py b/one_peace/metrics/map.py
index a6c23c1..0c7b6d4 100644
--- a/one_peace/metrics/map.py
+++ b/one_peace/metrics/map.py
@@ -5,7 +5,7 @@
 from sklearn.metrics import average_precision_score
 
 from .base_metric import BaseMetric
-from utils.data_utils import all_gather
+from ..utils.data_utils import all_gather
 
 
 class MAP(BaseMetric):
diff --git a/one_peace/metrics/recall.py b/one_peace/metrics/recall.py
index 1657e8a..499ead4 100644
--- a/one_peace/metrics/recall.py
+++ b/one_peace/metrics/recall.py
@@ -2,7 +2,7 @@
 import torch.distributed as dist
 
 from .base_metric import BaseMetric
-from utils.data_utils import all_gather
+from ..utils.data_utils import all_gather
 
 
 class Recall(BaseMetric):
diff --git a/one_peace/models/__init__.py b/one_peace/models/__init__.py
index cdbf76e..6528ee0 100644
--- a/one_peace/models/__init__.py
+++ b/one_peace/models/__init__.py
@@ -1,4 +1,5 @@
 from .one_peace.one_peace_base import OnePeaceBaseModel
 from .one_peace.one_peace_classify import OnePeaceClassifyModel
 from .one_peace.one_peace_pretrain import OnePeacePretrainModel
-from .one_peace.one_peace_retrieval import OnePeaceRetrievalModel
\ No newline at end of file
+from .one_peace.one_peace_retrieval import OnePeaceRetrievalModel
+from .one_peace.hub_interface import from_pretrained
\ No newline at end of file
diff --git a/one_peace/models/adapter/audio.py b/one_peace/models/adapter/audio.py
index d34f52c..bda650e 100644
--- a/one_peace/models/adapter/audio.py
+++ b/one_peace/models/adapter/audio.py
@@ -12,7 +12,7 @@
 
 from fairseq.modules import FairseqDropout
 from fairseq import utils
 
-from models.components import Embedding, trunc_normal_, LayerNorm, Linear
+from ..components import Embedding, trunc_normal_, LayerNorm, Linear
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/models/adapter/image.py b/one_peace/models/adapter/image.py
index 8aca51c..7c36f9a 100644
--- a/one_peace/models/adapter/image.py
+++ b/one_peace/models/adapter/image.py
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 
 from fairseq.modules import FairseqDropout
 
-from models.components import Embedding, trunc_normal_, LayerNorm
+from ..components import Embedding, trunc_normal_, LayerNorm
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/models/adapter/text.py b/one_peace/models/adapter/text.py
index 8daf393..c22328f 100644
--- a/one_peace/models/adapter/text.py
+++ b/one_peace/models/adapter/text.py
@@ -10,7 +10,7 @@
 
 from fairseq.modules import FairseqDropout
 from fairseq import utils
 
-from models.components import Embedding, trunc_normal_, LayerNorm
+from ..components import Embedding, trunc_normal_, LayerNorm
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/models/one_peace/hub_interface.py b/one_peace/models/one_peace/hub_interface.py
new file mode 100644
index 0000000..af071b9
--- /dev/null
+++ b/one_peace/models/one_peace/hub_interface.py
@@ -0,0 +1,184 @@
+
+import os
+import urllib.request
+import math
+import librosa
+from tqdm import tqdm
+from PIL import Image
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from fairseq import checkpoint_utils, utils
+
+from ...data.base_dataset import CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ...utils.data_utils import collate_tokens
+from ... import tasks
+from ... import models
+
+_MODELS = {
+    "ONE-PEACE": "http://one-peace-shanghai.oss-cn-shanghai.aliyuncs.com/one-peace.pt"
+}
+
+def _download(url: str, root: str):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        return download_target
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
+                  unit_divisor=1024) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    return download_target
+
+
+def from_pretrained(
+    model_name_or_path,
+    device=("cuda" if torch.cuda.is_available() else "cpu"),
+    dtype="float32",
+    download_root=None
+):
+
+    if os.path.isfile(model_name_or_path):
+        model_path = model_name_or_path
+    else:
+        model_path = _download(_MODELS[model_name_or_path], download_root or os.path.expanduser("~/.cache/one-peace"))
+
+    # utils.import_user_module(argparse.Namespace(user_dir='../../user_module'))
+    overrides = {'model':{'_name':'one_peace_retrieval'}}
+    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+        [model_path],
+        arg_overrides=overrides
+    )
+    model = models[0]
+
+    return OnePeaceHubInterface(saved_cfg, task, model, device, dtype)
+
+
+class OnePeaceHubInterface:
+    """A simple PyTorch Hub interface to ONE-PEACE."""
+
+    def __init__(self, cfg, task, model, device="cpu", dtype="float32"):
+        super().__init__()
+        self.model = model
+        self.device = device
+        self.dtype = dtype
+
+        # for text
+        self.dict = task.dict
+        self.bpe = task.bpe
+        self.eos = self.dict.eos()
+        self.pad = self.dict.pad()
+        # for image
+        mean = CLIP_DEFAULT_MEAN
+        std = CLIP_DEFAULT_STD
+        self.transform = transforms.Compose([
+            transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        # for audio
+        feature_encoder_spec = '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
+        self.feature_encoder_spec = eval(feature_encoder_spec)
+        self._features_size_map = {}
+
+        self.model.to(device)
+        self.model.eval()
+        if self.dtype == "bf16":
+            self.model.bfloat16()
+        elif self.dtype == "fp16":
+            self.model.half()
+        else:
+            self.model.float()
+
+    def cast_data_dtype(self, t):
+        if self.dtype == "bf16":
+            return t.to(dtype=torch.bfloat16)
+        elif self.dtype == "fp16":
+            return t.to(dtype=torch.half)
+        else:
+            return t
+
+    def _get_mask_indices_dims(self, size, feature_encoder_spec, padding=0, dilation=1):
+        if size not in self._features_size_map:
+            L_in = size
+            for (_, kernel_size, stride) in feature_encoder_spec:
+                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+                L_out = 1 + L_out // stride
+                L_in = L_out
+            self._features_size_map[size] = L_out
+        return self._features_size_map[size]
+
+    def process_text(self, text_list):
+        tokens_list = []
+        for text in text_list:
+            s = self.dict.encode_line(
+                line=self.bpe.encode(text),
+                add_if_not_exist=False,
+                append_eos=False
+            ).long()
+            s = s[:70]
+            s = torch.cat([s, torch.LongTensor([self.eos])])
+            tokens_list.append(s)
+        src_tokens = collate_tokens(tokens_list, pad_idx=self.pad).to(self.device)
+        src_tokens = self.cast_data_dtype(src_tokens)
+        return src_tokens
+
+    def process_image(self, image_list):
+        patch_images_list = []
+        for image_path in image_list:
+            image = Image.open(image_path).convert("RGB")
+            patch_image = self.transform(image)
+            patch_images_list.append(patch_image)
+        src_images = torch.stack(patch_images_list, dim=0).to(self.device)
+        src_images = self.cast_data_dtype(src_images)
+        return src_images
+
+    def process_audio(self, audio_list):
+        feats_list = []
+        audio_padding_mask_list = []
+        for audio in audio_list:
+            wav, curr_sample_rate = librosa.load(audio, sr=16000)
+            assert curr_sample_rate == 16000
+            feats = torch.tensor(wav)
+            with torch.no_grad():
+                feats = F.layer_norm(feats, feats.shape)
+            if feats.size(-1) > curr_sample_rate * 15:
+                start_idx = 0
+                end_idx = start_idx + curr_sample_rate * 15
+                feats = feats[start_idx:end_idx]
+            if feats.size(-1) < curr_sample_rate * 1:
+                feats = feats.repeat(math.ceil(curr_sample_rate * 1 / feats.size(-1)))
+                feats = feats[:curr_sample_rate * 1]
+            T = self._get_mask_indices_dims(feats.size(-1), self.feature_encoder_spec)
+            audio_padding_mask = torch.zeros(T + 1).bool()
+            feats_list.append(feats)
+            audio_padding_mask_list.append(audio_padding_mask)
+        src_audios = collate_tokens(feats_list, pad_idx=0).to(self.device)
+        src_audios = self.cast_data_dtype(src_audios)
+        audio_padding_masks = collate_tokens(audio_padding_mask_list, pad_idx=True).to(self.device)
+        return src_audios, audio_padding_masks
+
+    def extract_text_features(self, src_tokens):
+        return self.model(src_tokens=src_tokens, encoder_type="text")
+
+    def extract_image_features(self, src_images):
+        return self.model(src_images=src_images, encoder_type="image")
+
+    def extract_audio_features(self, src_audios, audio_padding_masks):
+        return self.model(src_audios=src_audios, audio_padding_masks=audio_padding_masks, encoder_type="audio")
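
The _get_mask_indices_dims helper above replays the wav2vec2-style feature
encoder downsampling: each (dim, kernel_size, stride) stage applies the
standard Conv1d output-length formula, and results are memoized per input
size. A standalone check of the arithmetic, showing that one second of 16 kHz
audio yields 49 frames (so the padding mask above is sized T + 1 = 50):

    def conv_out_length(L, spec, padding=0, dilation=1):
        # same recurrence as _get_mask_indices_dims
        for _, kernel_size, stride in spec:
            L = 1 + (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride
        return L

    spec = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] + [(512, 2, 2)]
    print(conv_out_length(16000, spec))  # -> 49
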
diff --git a/one_peace/models/one_peace/one_peace_base.py b/one_peace/models/one_peace/one_peace_base.py
index 843f31d..d2d63f1 100644
--- a/one_peace/models/one_peace/one_peace_base.py
+++ b/one_peace/models/one_peace/one_peace_base.py
@@ -4,7 +4,7 @@
 # found in the LICENSE file in the root directory.
""" -One-Piece Base Model Wrapper +ONE-PEACE Base Model Wrapper """ import logging @@ -16,13 +16,13 @@ from fairseq.models import register_model, BaseFairseqModel from fairseq import utils -from models.unify_model_config import UnifyModelConfig -from models.components import trunc_normal_ -from models.adapter.text import TextAdapter -from models.adapter.image import ImageAdapter -from models.adapter.audio import AudioAdapter -from models.transformer.transformer_encoder import TransformerEncoder -from models.components import Linear, LayerNorm +from ..unify_model_config import UnifyModelConfig +from ..components import trunc_normal_ +from ..adapter.text import TextAdapter +from ..adapter.image import ImageAdapter +from ..adapter.audio import AudioAdapter +from ..transformer.transformer_encoder import TransformerEncoder +from ..components import Linear, LayerNorm logger = logging.getLogger(__name__) diff --git a/one_peace/models/one_peace/one_peace_classify.py b/one_peace/models/one_peace/one_peace_classify.py index d5e9a19..05384b9 100644 --- a/one_peace/models/one_peace/one_peace_classify.py +++ b/one_peace/models/one_peace/one_peace_classify.py @@ -4,7 +4,7 @@ # found in the LICENSE file in the root directory. """ -One-Piece Classify +ONE-PEACE Classify """ from typing import Optional from dataclasses import dataclass, field @@ -17,8 +17,8 @@ from fairseq.distributed import fsdp_wrap from fairseq.modules.checkpoint_activations import checkpoint_wrapper -from models.unify_model_config import UnifyModelConfig -from models.one_peace.one_peace_base import ModelWrapper, OnePeaceClassifyHead, OnePeaceBaseModel, init_one_peace_params +from ..unify_model_config import UnifyModelConfig +from .one_peace_base import ModelWrapper, OnePeaceClassifyHead, OnePeaceBaseModel, init_one_peace_params logger = logging.getLogger(__name__) diff --git a/one_peace/models/one_peace/one_peace_pretrain.py b/one_peace/models/one_peace/one_peace_pretrain.py index 9ad4f6d..a91b8ba 100644 --- a/one_peace/models/one_peace/one_peace_pretrain.py +++ b/one_peace/models/one_peace/one_peace_pretrain.py @@ -4,7 +4,7 @@ # found in the LICENSE file in the root directory. """ -One-Piece Pretrain +ONE-PEACE Pretrain """ from typing import Optional from dataclasses import dataclass @@ -19,9 +19,9 @@ from fairseq.distributed import fsdp_wrap from fairseq.modules.checkpoint_activations import checkpoint_wrapper -from models.unify_model_config import UnifyModelConfig -from models.one_peace.one_peace_base import ModelWrapper, OnePeaceBaseModel, init_one_peace_params -from models.components import Linear, trunc_normal_ +from ..unify_model_config import UnifyModelConfig +from .one_peace_base import ModelWrapper, OnePeaceBaseModel, init_one_peace_params +from ..components import Linear, trunc_normal_ logger = logging.getLogger(__name__) diff --git a/one_peace/models/one_peace/one_peace_retrieval.py b/one_peace/models/one_peace/one_peace_retrieval.py index 4de2033..f5b98b5 100644 --- a/one_peace/models/one_peace/one_peace_retrieval.py +++ b/one_peace/models/one_peace/one_peace_retrieval.py @@ -4,7 +4,7 @@ # found in the LICENSE file in the root directory. 
""" -One-Piece Retrieval +ONE-PEACE Retrieval """ from typing import Optional from dataclasses import dataclass @@ -19,9 +19,9 @@ from fairseq.distributed import fsdp_wrap from fairseq.modules.checkpoint_activations import checkpoint_wrapper -from models.unify_model_config import UnifyModelConfig -from models.components import Linear -from models.one_peace.one_peace_base import ModelWrapper, OnePeaceBaseModel, init_one_peace_params +from ..unify_model_config import UnifyModelConfig +from ..components import Linear +from .one_peace_base import ModelWrapper, OnePeaceBaseModel, init_one_peace_params logger = logging.getLogger(__name__) diff --git a/one_peace/models/transformer/multihead_attention.py b/one_peace/models/transformer/multihead_attention.py index f87b53f..fb864fb 100644 --- a/one_peace/models/transformer/multihead_attention.py +++ b/one_peace/models/transformer/multihead_attention.py @@ -12,7 +12,7 @@ from fairseq import utils from fairseq.modules.fairseq_dropout import FairseqDropout -from models.components import Linear, LayerNorm +from ..components import Linear, LayerNorm logger = logging.getLogger(__name__) diff --git a/one_peace/models/transformer/transformer_encoder.py b/one_peace/models/transformer/transformer_encoder.py index 75a6212..bdee010 100644 --- a/one_peace/models/transformer/transformer_encoder.py +++ b/one_peace/models/transformer/transformer_encoder.py @@ -14,7 +14,7 @@ LayerDropModuleList ) -from models.components import LayerNorm +from ..components import LayerNorm from .transformer_layer import TransformerEncoderLayer logger = logging.getLogger(__name__) diff --git a/one_peace/models/transformer/transformer_layer.py b/one_peace/models/transformer/transformer_layer.py index d895aa7..8397dc2 100644 --- a/one_peace/models/transformer/transformer_layer.py +++ b/one_peace/models/transformer/transformer_layer.py @@ -12,7 +12,7 @@ from fairseq.modules.fairseq_dropout import FairseqDropout -from models.components import Linear, LayerNorm +from ..components import Linear, LayerNorm from .multihead_attention import MultiheadAttention logger = logging.getLogger(__name__) diff --git a/one_peace/tasks/audio_tasks/aqa.py b/one_peace/tasks/audio_tasks/aqa.py index 339889e..c4163cd 100644 --- a/one_peace/tasks/audio_tasks/aqa.py +++ b/one_peace/tasks/audio_tasks/aqa.py @@ -9,9 +9,9 @@ from fairseq.tasks import register_task -from tasks.base_task import BaseTask, BaseTaskConfig -from data.audio_data.aqa_dataset import AQADataset -from metrics import Accuracy +from ..base_task import BaseTask, BaseTaskConfig +from ...data.audio_data.aqa_dataset import AQADataset +from ...metrics import Accuracy logger = logging.getLogger(__name__) diff --git a/one_peace/tasks/audio_tasks/audio_classify.py b/one_peace/tasks/audio_tasks/audio_classify.py index 9bff7ef..d823be4 100644 --- a/one_peace/tasks/audio_tasks/audio_classify.py +++ b/one_peace/tasks/audio_tasks/audio_classify.py @@ -9,9 +9,9 @@ from fairseq.tasks import register_task -from tasks.base_task import BaseTask, BaseTaskConfig -from data.audio_data.audio_classify_dataset import AudioClassifyDataset -from metrics import MAP +from ..base_task import BaseTask, BaseTaskConfig +from ...data.audio_data.audio_classify_dataset import AudioClassifyDataset +from ...metrics import MAP logger = logging.getLogger(__name__) diff --git a/one_peace/tasks/audio_tasks/audio_text_retrieval.py b/one_peace/tasks/audio_tasks/audio_text_retrieval.py index 3793097..c3f4885 100644 --- a/one_peace/tasks/audio_tasks/audio_text_retrieval.py +++ 
@@ -13,10 +13,10 @@
 
 from fairseq.tasks import register_task
 from fairseq.utils import move_to_cuda
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.audio_data.audio_text_retrieval_dataset import AudioTextRetrievalDataset
-from utils.data_utils import new_islice, all_gather
-from metrics import Recall
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.audio_data.audio_text_retrieval_dataset import AudioTextRetrievalDataset
+from ...utils.data_utils import new_islice, all_gather
+from ...metrics import Recall
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/audio_tasks/vggsound.py b/one_peace/tasks/audio_tasks/vggsound.py
index 549ac8d..e16e341 100644
--- a/one_peace/tasks/audio_tasks/vggsound.py
+++ b/one_peace/tasks/audio_tasks/vggsound.py
@@ -9,9 +9,9 @@
 
 from fairseq.tasks import register_task
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.audio_data.vggsound_dataset import VggsoundDataset
-from metrics import Accuracy
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.audio_data.vggsound_dataset import VggsoundDataset
+from ...metrics import Accuracy
 
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/base_task.py b/one_peace/tasks/base_task.py
index 1e2527e..41f6ba8 100644
--- a/one_peace/tasks/base_task.py
+++ b/one_peace/tasks/base_task.py
@@ -18,8 +18,8 @@
 
 from fairseq.tasks import FairseqTask, register_task
 from omegaconf import DictConfig
 
-from data.tsv_reader import TSVReader
-from data.iterators import EpochBatchIterator
+from ..data.tsv_reader import TSVReader
+from ..data.iterators import EpochBatchIterator
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/pretrain_tasks/audio_text_pretrain.py b/one_peace/tasks/pretrain_tasks/audio_text_pretrain.py
index 03f0768..72c4511 100644
--- a/one_peace/tasks/pretrain_tasks/audio_text_pretrain.py
+++ b/one_peace/tasks/pretrain_tasks/audio_text_pretrain.py
@@ -12,9 +12,9 @@
 
 from fairseq.tasks import register_task
 from fairseq.utils import move_to_cuda
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.pretrain_data.audio_text_pretrain_dataset import AudioTextPretrainDataset
-from metrics import Recall
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.pretrain_data.audio_text_pretrain_dataset import AudioTextPretrainDataset
+from ...metrics import Recall
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/pretrain_tasks/image_text_pretrain.py b/one_peace/tasks/pretrain_tasks/image_text_pretrain.py
index fce0a05..13b51e1 100644
--- a/one_peace/tasks/pretrain_tasks/image_text_pretrain.py
+++ b/one_peace/tasks/pretrain_tasks/image_text_pretrain.py
@@ -13,10 +13,10 @@
 
 from fairseq.tasks import register_task
 from fairseq.utils import move_to_cuda
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.pretrain_data.image_text_pretrain_dataset import ImageTextPretrainDataset
-from utils.data_utils import new_islice, all_gather
-from metrics import Recall
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.pretrain_data.image_text_pretrain_dataset import ImageTextPretrainDataset
+from ...utils.data_utils import new_islice, all_gather
+from ...metrics import Recall
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/vision_tasks/image_classify.py b/one_peace/tasks/vision_tasks/image_classify.py
index 2408623..9f54fae 100644
--- a/one_peace/tasks/vision_tasks/image_classify.py
+++ b/one_peace/tasks/vision_tasks/image_classify.py
@@ -10,9 +10,9 @@
 
 from fairseq.tasks import register_task
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.vision_data.image_classify_dataset import ImageClassifyDataset
-from metrics import Accuracy
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.vision_data.image_classify_dataset import ImageClassifyDataset
+from ...metrics import Accuracy
 
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/vl_tasks/image_text_retrieval.py b/one_peace/tasks/vl_tasks/image_text_retrieval.py
index 79a3d3c..1d567e0 100644
--- a/one_peace/tasks/vl_tasks/image_text_retrieval.py
+++ b/one_peace/tasks/vl_tasks/image_text_retrieval.py
@@ -13,10 +13,10 @@
 
 from fairseq.tasks import register_task
 from fairseq.utils import move_to_cuda
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.vl_data.image_text_retrieval_dataset import ImageTextRetrievalDataset
-from utils.data_utils import new_islice, all_gather
-from metrics import Recall
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.vl_data.image_text_retrieval_dataset import ImageTextRetrievalDataset
+from ...utils.data_utils import new_islice, all_gather
+from ...metrics import Recall
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/vl_tasks/nlvr2.py b/one_peace/tasks/vl_tasks/nlvr2.py
index 20b2d35..73b3b8b 100644
--- a/one_peace/tasks/vl_tasks/nlvr2.py
+++ b/one_peace/tasks/vl_tasks/nlvr2.py
@@ -9,9 +9,9 @@
 
 from fairseq.tasks import register_task
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.vl_data.nlvr2_dataset import Nlvr2Dataset
-from metrics import Accuracy
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.vl_data.nlvr2_dataset import Nlvr2Dataset
+from ...metrics import Accuracy
 
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/vl_tasks/refcoco.py b/one_peace/tasks/vl_tasks/refcoco.py
index 07e6b23..f5ca22b 100644
--- a/one_peace/tasks/vl_tasks/refcoco.py
+++ b/one_peace/tasks/vl_tasks/refcoco.py
@@ -9,9 +9,9 @@
 
 from fairseq.tasks import register_task
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.vl_data.refcoco_dataset import RefCOCODataset
-from metrics import IouAcc
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.vl_data.refcoco_dataset import RefCOCODataset
+from ...metrics import IouAcc
 
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/tasks/vl_tasks/vqa.py b/one_peace/tasks/vl_tasks/vqa.py
index e0b406c..20a8bb9 100644
--- a/one_peace/tasks/vl_tasks/vqa.py
+++ b/one_peace/tasks/vl_tasks/vqa.py
@@ -9,9 +9,9 @@
 
 from fairseq.tasks import register_task
 
-from tasks.base_task import BaseTask, BaseTaskConfig
-from data.vl_data.vqa_dataset import VqaDataset
-from metrics import Accuracy
+from ..base_task import BaseTask, BaseTaskConfig
+from ...data.vl_data.vqa_dataset import VqaDataset
+from ...metrics import Accuracy
 
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/train.py b/one_peace/train.py
index faf201c..7a7f938 100644
--- a/one_peace/train.py
+++ b/one_peace/train.py
@@ -14,6 +14,8 @@
 import sys
 from typing import Any, Dict, List, Optional, Tuple
 
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+
 # We need to setup root logger before importing any fairseq libraries.
 logging.basicConfig(
     format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
@@ -39,7 +41,7 @@
 from fairseq.model_parallel.megatron_trainer import MegatronTrainer
 
 # from fairseq.trainer import Trainer
-from trainer import Trainer
+from one_peace.trainer import Trainer
 
 
 def main(cfg: FairseqConfig) -> None:
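
train.py above and evaluate.py earlier in the patch prepend the repository
root to sys.path before any one_peace import runs, so the scripts keep working
when launched directly rather than as installed modules. A sketch of the
effect, assuming the file lives at <repo>/one_peace/train.py:

    import os
    import sys

    # __file__ == "<repo>/one_peace/train.py"
    # one dirname  -> "<repo>/one_peace"
    # two dirnames -> "<repo>"
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

    # with "<repo>" on sys.path, absolute package imports resolve in script mode too
    from one_peace.trainer import Trainer
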
diff --git a/one_peace/trainer.py b/one_peace/trainer.py
index af4a274..432b6c9 100644
--- a/one_peace/trainer.py
+++ b/one_peace/trainer.py
@@ -25,9 +25,9 @@
 
 from fairseq.nan_detector import NanDetector
 from fairseq.optim import lr_scheduler
 
-from utils.layer_decay import LayerDecayValueAssigner, get_parameter_groups
-from utils.ema_module import EMAModule
-from optim import FP16Optimizer, MemoryEfficientFP16Optimizer, AMPOptimizer
+from .utils.layer_decay import LayerDecayValueAssigner, get_parameter_groups
+from .utils.ema_module import EMAModule
+from .optim import FP16Optimizer, MemoryEfficientFP16Optimizer, AMPOptimizer
 
 logger = logging.getLogger(__name__)
diff --git a/one_peace/user_module/__init__.py b/one_peace/user_module/__init__.py
index a931946..e66f934 100644
--- a/one_peace/user_module/__init__.py
+++ b/one_peace/user_module/__init__.py
@@ -1,7 +1,7 @@
-import criterions
-import data
-import metrics
-import models
-import optim
-import tasks
-import utils
\ No newline at end of file
+from one_peace import criterions
+from one_peace import data
+from one_peace import metrics
+from one_peace import models
+from one_peace import optim
+from one_peace import tasks
+from one_peace import utils
\ No newline at end of file
diff --git a/one_peace/utils/hub_interface.py b/one_peace/utils/hub_interface.py
new file mode 100644
index 0000000..bb50ddd
--- /dev/null
+++ b/one_peace/utils/hub_interface.py
@@ -0,0 +1,184 @@
+
+import os
+import urllib.request
+import math
+import librosa
+from tqdm import tqdm
+from PIL import Image
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from fairseq import checkpoint_utils, utils
+
+from ..data.base_dataset import CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD
+from ..utils.data_utils import collate_tokens
+from .. import tasks
+from .. import models
+
+_MODELS = {
+    "ONE-PEACE": "http://one-peace-shanghai.oss-cn-shanghai.aliyuncs.com/one-peace.pt"
+}
+
+def _download(url: str, root: str):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        return download_target
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
+                  unit_divisor=1024) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    return download_target
+
+
+def from_pretrained(
+    model_name_or_path,
+    device=("cuda" if torch.cuda.is_available() else "cpu"),
+    dtype="float32",
+    download_root=None
+):
+
+    if os.path.isfile(model_name_or_path):
+        model_path = model_name_or_path
+    else:
+        model_path = _download(_MODELS[model_name_or_path], download_root or os.path.expanduser("~/.cache/one-peace"))
+
+    # utils.import_user_module(argparse.Namespace(user_dir='../../user_module'))
+    overrides = {'model':{'_name':'one_peace_retrieval'}}
+    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+        [model_path],
+        arg_overrides=overrides
+    )
+    model = models[0]
+
+    return OnePeaceHubInterface(saved_cfg, task, model, device, dtype)
+
+
+class OnePeaceHubInterface:
+    """A simple PyTorch Hub interface to ONE-PEACE."""
+
+    def __init__(self, cfg, task, model, device="cpu", dtype="float32"):
+        super().__init__()
+        self.model = model
+        self.device = device
+        self.dtype = dtype
+
+        # for text
+        self.dict = task.dict
+        self.bpe = task.bpe
+        self.eos = self.dict.eos()
+        self.pad = self.dict.pad()
+        # for image
+        mean = CLIP_DEFAULT_MEAN
+        std = CLIP_DEFAULT_STD
+        self.transform = transforms.Compose([
+            transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        # for audio
+        feature_encoder_spec = '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
+        self.feature_encoder_spec = eval(feature_encoder_spec)
+        self._features_size_map = {}
+
+        self.model.to(device)
+        self.model.eval()
+        if self.dtype == "bf16":
+            self.model.bfloat16()
+        elif self.dtype == "fp16":
+            self.model.half()
+        else:
+            self.model.float()
+
+    def cast_data_dtype(self, t):
+        if self.dtype == "bf16":
+            return t.to(dtype=torch.bfloat16)
+        elif self.dtype == "fp16":
+            return t.to(dtype=torch.half)
+        else:
+            return t
+
+    def _get_mask_indices_dims(self, size, feature_encoder_spec, padding=0, dilation=1):
+        if size not in self._features_size_map:
+            L_in = size
+            for (_, kernel_size, stride) in feature_encoder_spec:
+                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+                L_out = 1 + L_out // stride
+                L_in = L_out
+            self._features_size_map[size] = L_out
+        return self._features_size_map[size]
+
+    def process_text(self, text_list):
+        tokens_list = []
+        for text in text_list:
+            s = self.dict.encode_line(
+                line=self.bpe.encode(text),
+                add_if_not_exist=False,
+                append_eos=False
+            ).long()
+            s = s[:70]
+            s = torch.cat([s, torch.LongTensor([self.eos])])
+            tokens_list.append(s)
+        src_tokens = collate_tokens(tokens_list, pad_idx=self.pad).to(self.device)
+        src_tokens = self.cast_data_dtype(src_tokens)
+        return src_tokens
+
+    def process_image(self, image_list):
+        patch_images_list = []
+        for image_path in image_list:
+            image = Image.open(image_path).convert("RGB")
+            patch_image = self.transform(image)
+            patch_images_list.append(patch_image)
+        src_images = torch.stack(patch_images_list, dim=0).to(self.device)
+        src_images = self.cast_data_dtype(src_images)
+        return src_images
+
+    def process_audio(self, audio_list):
+        feats_list = []
+        audio_padding_mask_list = []
+        for audio in audio_list:
+            wav, curr_sample_rate = librosa.load(audio, sr=16000)
+            assert curr_sample_rate == 16000
+            feats = torch.tensor(wav)
+            with torch.no_grad():
+                feats = F.layer_norm(feats, feats.shape)
+            if feats.size(-1) > curr_sample_rate * 15:
+                start_idx = 0
+                end_idx = start_idx + curr_sample_rate * 15
+                feats = feats[start_idx:end_idx]
+            if feats.size(-1) < curr_sample_rate * 1:
+                feats = feats.repeat(math.ceil(curr_sample_rate * 1 / feats.size(-1)))
+                feats = feats[:curr_sample_rate * 1]
+            T = self._get_mask_indices_dims(feats.size(-1), self.feature_encoder_spec)
+            audio_padding_mask = torch.zeros(T + 1).bool()
+            feats_list.append(feats)
+            audio_padding_mask_list.append(audio_padding_mask)
+        src_audios = collate_tokens(feats_list, pad_idx=0).to(self.device)
+        src_audios = self.cast_data_dtype(src_audios)
+        audio_padding_masks = collate_tokens(audio_padding_mask_list, pad_idx=True).to(self.device)
+        return src_audios, audio_padding_masks
+
+    def extract_text_features(self, src_tokens):
+        return self.model(src_tokens=src_tokens, encoder_type="text")
+
+    def extract_image_features(self, src_images):
+        return self.model(src_images=src_images, encoder_type="image")
+
+    def extract_audio_features(self, src_audios, audio_padding_masks):
+        return self.model(src_audios=src_audios, audio_padding_masks=audio_padding_masks, encoder_type="audio")
diff --git a/requirements.txt b/requirements.txt
index e9b93cd..816a308 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
--e ./fairseq/
+./fairseq/
 opencv-python==4.7.0.72
 requests==2.28.1
 tensorboardX==2.6
@@ -9,4 +9,5 @@ pillow==8.4.0
 timm==0.6.11
 iopath==0.1.10
 pydub==0.25.1
-scikit-learn==1.0.2
\ No newline at end of file
+scikit-learn==1.0.2
+librosa==0.10.0
\ No newline at end of file
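
requirements.txt now installs the bundled fairseq as a regular, non-editable
package and adds librosa, which the hub interface uses for audio loading. Once
installed, the features returned by the hub interface can be compared across
modalities; a retrieval-style sketch reusing the variables from the usage
example at the top (the explicit L2 normalization is an assumption, added for
safety rather than taken from this patch):

    import torch.nn.functional as F

    text_features = F.normalize(text_features, dim=-1)
    image_features = F.normalize(image_features, dim=-1)
    audio_features = F.normalize(audio_features, dim=-1)

    # pairwise cosine similarities, e.g. text-to-image and text-to-audio
    t2i = text_features @ image_features.T
    t2a = text_features @ audio_features.T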