by Feng Hong, Jiangchao Yao, Zhihan Zhou, Ya Zhang, and Yanfeng Wang at SJTU and Shanghai AI Lab.
International Conference on Learning Representations (ICLR), 2023.
This repository is the official PyTorch implementation of RECORDS.
If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.
@inproceedings{hong2023long,
title={Long-Tailed Partial Label Learning via Dynamic Rebalancing},
author={Hong, Feng and Yao, Jiangchao and Zhou, Zhihan and Zhang, Ya and Wang, Yanfeng},
booktitle={{ICLR}},
year={2023}
}
- We delve into a more practical but under-explored LT-PLL scenario and identify several challenges in this task that cannot be addressed, and may even lead to failure, by a straightforward combination of current long-tailed learning and partial label learning methods.
- We propose RECORDS, a novel method for LT-PLL that performs dynamic adjustment to rebalance training without requiring any prior knowledge of the class distribution. Theoretical and empirical analyses show that the dynamic parametric class distribution progressively approaches the oracle class distribution while remaining friendly to label disambiguation.
- Our method is orthogonal to existing PLL methods and can be easily plugged into them in an end-to-end manner.
The project is tested under the following environment settings:
- OS: Ubuntu 18.04.5
- GPU: NVIDIA GeForce RTX 3090
- Python: 3.7.10
- PyTorch: 1.7.1
- Torchvision: 0.8.2
- Cudatoolkit: 11.0.221
- Numpy: 1.21.2
After the preparation work, the whole project should have the following structure:
./RECORDS-LTPLL
├── README.md
├── models # models
│ ├── resnet.py
├── utils # utils: datasets, losses, etc.
│ ├── cifar10.py
│ ├── cifar100.py
│ ├── imbalance_cifar.py
│ ├── randaugment.py
│ ├── utils_algo.py
│ ├── utils_loss.py
├── utils_solar # utils for SoLar
│ ├── data.py
│ ├── resnet.py
│ ├── general.py
├── train.py # train for CORR (+ RECORDS)
└── train_solar.py # train for SoLar
For a self-training PLL loss:
# loss function (batch forwards)
loss = caculate_loss(logits, labels, self.confidence[index, :])
if update_target:
    # disambiguation
    self.confidence[index, :] = update_confidence(logits, self.confidence[index, :])
We can easily add RECORDS to the loss function:
# loss function (batch forwards)
loss = caculate_loss(logits, labels, self.confidence[index, :])
# momentum updates of the running feature mean
if self.feat_mean is None:
    self.feat_mean = 0.1 * feat.detach().mean(0)
else:
    self.feat_mean = 0.9 * self.feat_mean + 0.1 * feat.detach().mean(0)
if update_target:
    # debias and disambiguation
    bias = model.module.fc(self.feat_mean.unsqueeze(0)).detach()
    bias = F.softmax(bias, dim=1)
    logits_rebalanced = logits - torch.log(bias + 1e-9)
    self.confidence[index, :] = update_confidence(logits_rebalanced, self.confidence[index, :])
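For reference, here is a self-contained sketch of the same debiasing idea packaged as a small helper class. The name `RECORDSDebias` and its interface are illustrative assumptions for this README, not the repository's actual API; the real logic lives in the training code and ./utils/utils_loss.py.

```python
import torch
import torch.nn.functional as F


class RECORDSDebias:
    """Illustrative helper (hypothetical API): keeps a momentum-averaged
    feature and rebalances logits by subtracting the estimated log-prior."""

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.feat_mean = None  # running mean of backbone features

    def update(self, feat):
        # exponential moving average of the batch-mean feature
        batch_mean = feat.detach().mean(0)
        if self.feat_mean is None:
            self.feat_mean = (1 - self.momentum) * batch_mean
        else:
            self.feat_mean = self.momentum * self.feat_mean + (1 - self.momentum) * batch_mean

    def debias(self, logits, classifier):
        # estimate the current (biased) class distribution from the mean feature,
        # then subtract its log-prior before label disambiguation
        bias = F.softmax(classifier(self.feat_mean.unsqueeze(0)).detach(), dim=1)
        return logits - torch.log(bias + 1e-9)
```

In a training loop this would be used as `records.update(feat)` after each forward pass and `logits_rebalanced = records.debias(logits, model.fc)` before updating the confidences, mirroring the snippet above.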
For the CIFAR dataset, no additional data preparation is required. The first run will automatically download CIFAR to "./data".
Download the PLL version of PASCAL VOC 2007 and extract it to "./data/VOC2017/". [Download (Google Drive)]
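For intuition about the ambiguity level $q$ used below, here is a minimal sketch (an illustrative assumption, not the repository's exact generation code; see ./utils/cifar10.py and ./utils/imbalance_cifar.py for the actual logic) of how uniform candidate label sets are typically built: the true label is always kept, and every other label is added independently with probability $q$.

```python
import numpy as np

def generate_uniform_partial_labels(true_labels, num_classes, q, seed=0):
    """Hypothetical helper: build multi-hot candidate-label sets where each
    incorrect label is flipped into the set with probability q."""
    rng = np.random.default_rng(seed)
    n = len(true_labels)
    partial = np.zeros((n, num_classes), dtype=np.float32)
    partial[np.arange(n), true_labels] = 1.0   # always keep the ground-truth label
    flips = rng.random((n, num_classes)) < q   # candidate wrong labels
    flips[np.arange(n), true_labels] = False
    partial[flips] = 1.0
    return partial
```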
Run CORR [1] on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$:
CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-10 --dataset cifar10_im --num_class 10 --dist_url 'tcp://localhost:10000' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.3 --imb_factor 0.01
Run CORR + RECORDS on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$:
CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-10 --dataset cifar10_im --num_class 10 --dist_url 'tcp://localhost:10001' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.3 --imb_factor 0.01 --records
Note: `--records` means applying RECORDS on top of the PLL baseline.
Run CORR + RECORDS on CIFAR-100-LT-NU with $q=0.03$ and imbalance ratio $\rho = 0.01$:
CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-100 --dataset cifar100_im --num_class 100 --dist_url 'tcp://localhost:10002' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.03 --imb_factor 0.01 --records --hierarchical
Note: `--hierarchical` means using the non-uniform version of the dataset, i.e., CIFAR-100-LT-NU.
Run SoLar [2] (w/ Mixup) on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$:
CUDA_VISIBLE_DEVICES=0 python -u train_solar.py --exp_dir experiment/SoLar-CIFAR-100 --dataset cifar10 --num_class 10 --partial_rate 0.3 --imb_type exp --imb_ratio 100 --est_epochs 100 --rho_range 0.2,0.6 --gamma 0.1,0.01 --epochs 800 --lr 0.01 --wd 1e-3 --cosine --seed 123
Note: SoLar is a concurrent LT-PLL work published at NeurIPS 2022. It improves the label disambiguation process in LT-PLL through an optimal transport technique. Different from SoLar, RECORDS addresses the LT-PLL problem from the perspective of rebalancing, in a lightweight and effective manner.
Note: on CIFAR-100-LT, change these parameters to `--est_epochs 20 --rho_range 0.2,0.5 --gamma 0.05,0.01`.
Run CORR + RECORDS (w/ Mixup) on CIFAR-10-LT with $q=0.3$ and imbalance ratio $\rho = 0.01$:
CUDA_VISIBLE_DEVICES=0 python -u train.py --exp_dir experiment/CORR-CIFAR-10 --dataset cifar10_im --num_class 10 --dist_url 'tcp://localhost:10003' --multiprocessing_distributed --world_size 1 --rank 0 --seed 123 --arch resnet18 --upd_start 1 --lr 0.01 --wd 1e-3 --cosine --epochs 800 --print_freq 100 --partial_rate 0.3 --imb_factor 0.01 --records --mixup
Note: `--mixup` means using Mixup.
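For completeness, here is a minimal sketch of standard Mixup as it could be applied to PLL training (an illustrative assumption; the actual integration with the CORR loss in train.py may differ): inputs and the soft disambiguation targets are mixed with the same coefficient.

```python
import numpy as np
import torch

def mixup_batch(images, soft_targets, alpha=1.0):
    """Illustrative Mixup: convex combination of inputs and soft targets
    (e.g., the per-sample confidence vectors used for disambiguation)."""
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0), device=images.device)
    mixed_images = lam * images + (1 - lam) * images[perm]
    mixed_targets = lam * soft_targets + (1 - lam) * soft_targets[perm]
    return mixed_images, mixed_targets
```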
Results on CIFAR-10-LT (accuracy, %):

| Imbalance ratio | 50 | 50 | 50 | 100 | 100 | 100 |
|---|---|---|---|---|---|---|
| Ambiguity $q$ | 0.3 | 0.5 | 0.7 | 0.3 | 0.5 | 0.7 |
| CORR | 76.12 | 56.45 | 41.56 | 66.38 | 50.09 | 38.11 |
| CORR + Oracle-LA [3] | 36.27 | 17.61 | 12.77 | 29.97 | 15.80 | 11.75 |
| CORR + RECORDS | 82.57 | 80.28 | 67.24 | 77.66 | 72.90 | 57.46 |
| SoLar (w/ Mixup) | 83.88 | 76.55 | 54.61 | 75.38 | 70.63 | 53.15 |
| CORR + RECORDS (w/ Mixup) | 84.25 | 82.50 | 71.24 | 79.79 | 74.07 | 62.25 |
Results on CIFAR-100-LT (accuracy, %):

| Imbalance ratio | 50 | 50 | 50 | 100 | 100 | 100 |
|---|---|---|---|---|---|---|
| Ambiguity $q$ | 0.03 | 0.05 | 0.07 | 0.03 | 0.05 | 0.07 |
| CORR | 42.29 | 38.03 | 36.59 | 38.39 | 34.09 | 31.05 |
| CORR + Oracle-LA | 22.56 | 5.59 | 3.12 | 11.37 | 3.32 | 1.98 |
| CORR + RECORDS | 48.06 | 45.56 | 42.51 | 42.25 | 40.59 | 38.65 |
| SoLar (w/ Mixup) | 47.93 | 46.85 | 45.10 | 42.51 | 41.71 | 39.15 |
| CORR + RECORDS (w/ Mixup) | 52.08 | 50.58 | 47.91 | 46.57 | 45.22 | 44.73 |
- Add your model to "./models" and load the model in train.py.
- Implement the functions specific to your model (see ./utils/utils_loss.py) and use them in train.py.
- Create the PLL version of your dataset and add it to "./data".
- Implement the dataset class (e.g., ./utils/cifar10.py); a minimal sketch follows this list.
- Load your data in train.py.
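A minimal sketch of what such a dataset class might look like (the class name and return signature are illustrative assumptions; follow ./utils/cifar10.py for the exact interface expected by train.py, which also needs the sample index so that `self.confidence` can be updated):

```python
import torch
from torch.utils.data import Dataset


class PartialLabelDataset(Dataset):
    """Hypothetical template for a PLL dataset: returns the image, its
    candidate-label vector, the ground-truth label (evaluation only),
    and the sample index used to update the confidence matrix."""

    def __init__(self, images, partial_labels, true_labels, transform=None):
        self.images = images                  # e.g., array of images
        self.partial_labels = partial_labels  # (N, C) multi-hot candidate sets
        self.true_labels = true_labels        # (N,) ground-truth labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        partial = torch.as_tensor(self.partial_labels[index]).float()
        return image, partial, self.true_labels[index], index
```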
We borrow some code from PiCO, LDAM-DRW, PRODEN, SADE, and SoLar.
[1] D.-D. Wu, D.-B. Wang, M.-L. Zhang. Revisiting Consistency Regularization for Deep Partial Label Learning. ICML, 2022.
[2] H. Wang, M. Xia, Y. Li, et al. SoLar: Sinkhorn Label Refinery for Imbalanced Partial-Label Learning. NeurIPS, 2022.
[3] A. K. Menon, S. Jayasumana, A. S. Rawat, et al. Long-Tail Learning via Logit Adjustment. ICLR, 2021.
If you have any problem with this code, please feel free to contact feng.hong@sjtu.edu.cn.