Skip to content

Commit

Permalink
[Feature]: Add pixmim algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanLiuuuuuu committed Mar 23, 2023
1 parent ae03d92 commit 03f6b8a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmselfsup/models/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
from .simmim import SimMIM
from .simsiam import SimSiam
from .swav import SwAV
from .pixmim import PixMIM

# Public API of the algorithms package. Every entry is separated by a comma:
# two adjacent string literals would be silently concatenated into a single
# (nonexistent) export name.
__all__ = [
    'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',
    'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam',
    'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA',
    'MixMIM', 'PixMIM'
]
39 changes: 39 additions & 0 deletions mmselfsup/models/algorithms/pixmim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

import torch

from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from .mae import MAE


@MODELS.register_module()
class PixMIM(MAE):
    """PixMIM algorithm, implemented as a thin extension of MAE.

    Implementation of `PixMIM: Rethinking Pixel Reconstruction in
    Masked Image Modeling <https://arxiv.org/pdf/2303.02416.pdf>`_.

    Only the training loss is overridden here; the initialization
    arguments are those of MAE.
    """

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """Compute the training loss for one batch.

        Args:
            inputs (List[torch.Tensor]): The input images. Only the first
                element is consumed by this method.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function. Not read by this method.
            **kwargs: Accepted for interface compatibility; not read by
                this method.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the single key
            ``loss`` holding the value returned by the head.
        """
        images = inputs[0]

        # Targets are generated from the raw input images before the
        # backbone runs. NOTE(review): presumably these are the
        # low-frequency reconstruction targets -- confirm against the
        # configured target generator.
        recon_targets = self.target_generator(images)

        # ids_restore: the same as that in original repo, which is used
        # to recover the original order of tokens in decoder.
        encoded, mask, ids_restore = self.backbone(images)
        reconstruction = self.neck(encoded, ids_restore)

        return {'loss': self.head(reconstruction, recon_targets, mask)}

0 comments on commit 03f6b8a

Please sign in to comment.