Commit

release codes

csyxwei committed Dec 13, 2024
1 parent 7a50532 commit 2a78b7b
Showing 65 changed files with 9,363 additions and 3 deletions.
21 changes: 21 additions & 0 deletions .gitignore
@@ -0,0 +1,21 @@
_debug*
.env
__pycache__
_sc.py
*.ckpt
*.bin

checkpoints
.idea
.idea/workspace.xml
.DS_Store
*/__pycache__
*.pyc
*.iml
__pycache__/
*/__pycache__/
*/*/__pycache__/
*/*/*/__pycache__/
*/*/*/*/__pycache__/
*/*/*/*/*/__pycache__/
*/*/*/*/*/*/__pycache__/
694 changes: 694 additions & 0 deletions 3rd-party-licenses.txt


661 changes: 661 additions & 0 deletions LICENSE


102 changes: 99 additions & 3 deletions README.md
@@ -1,4 +1,100 @@
-# iPOSE
-Official pytorch implementation of "Inferring and Leveraging Parts from Object Shape for Improving Semantic Image Synthesis" (CVPR2023)
-Code will be coming soon.
# [Inferring and Leveraging Parts from Object Shape for Improving Semantic Image Synthesis](https://arxiv.org/abs/2305.19547)

---


## Method Details

---

![teaser](assets/teaser.png)

We propose iPOSE, a method that infers parts from object shape and leverages them to improve semantic image synthesis. It generates more photo-realistic parts from a given semantic map while retaining the flexibility to control the generated objects.

## Quick Start

---

### Environment Setup

```
git clone https://github.com/csyxwei/iPOSE.git
cd iPOSE
conda create -n ipose python=3.7.6
conda activate ipose
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
```
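
To confirm that the pinned CUDA 11.0 build is active, a quick optional check (it should print `1.7.1+cu110` and `True` on a CUDA machine):

```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```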

### Preparing Dataset

Please follow the instructions in [SPADE](https://github.com/NVlabs/SPADE) to prepare the COCO-Stuff, Cityscapes, and ADE20K datasets. In addition, we obtain the instance maps of ADE20K from [instancesegmentation](http://sceneparsing.csail.mit.edu/data/ChallengeData2017/annotations_instance.tar).

The part dataset we used can be downloaded from [GoogleDrive](https://drive.google.com/file/d/1vZFrXQg1TnhMJh8c_g7o8bXJg5oV9nLC/view?usp=sharing).

For COCO, we employ [./utils/coco_util/create_ins_dict.py](./utils/coco_util/create_ins_dict.py) to precompute the instance parameters for fast training.
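
For intuition, here is a rough sketch of the kind of cache such a preprocessing step could build (a hypothetical reimplementation, not the actual script: we assume the dict maps each image name to per-instance bounding boxes; consult `create_ins_dict.py` for the real format):

```
import os
import numpy as np
from PIL import Image

inst_dir, out_path = 'datasets/COCO/train_inst', 'datasets/COCO/train_dict.npy'
ins_dict = {}
for name in sorted(os.listdir(inst_dir)):
    # each instance map stores one integer instance id per pixel
    inst = np.array(Image.open(os.path.join(inst_dir, name)))
    boxes = {}
    for ins_id in np.unique(inst):
        ys, xs = np.nonzero(inst == ins_id)
        boxes[int(ins_id)] = (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
    ins_dict[os.path.splitext(name)[0]] = boxes
np.save(out_path, ins_dict)  # read back with np.load(..., allow_pickle=True)
```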

The final data structure should look like this:

```
datasets
├── Cityscapes
│ ├── leftImg8bit
│ ├── gtFine
│ ├── ...
├── ADEChallengeData2016
│ ├── images
│ ├── annotations
│ ├── annotations_instance
│ ├── ...
├── COCO
│ ├── train_img
│ ├── train_label
│ ├── train_inst
│ ├── val_img
│ ├── val_label
│ ├── val_inst
│ ├── train_dict.npy
│ ├── val_dict.npy
│ ├── ...
├── Full_Parts
│ ├── ...
```

### Testing

Download the pretrained models from [GoogleDrive](https://drive.google.com/drive/folders/1Vz5j6PaLl_tPDacGTdJSVyiA8UH74Ftp?usp=sharing) and save them under `./checkpoints`.

After that, you can use the provided scripts in [`scripts`](./scripts) for testing and evaluation. For example:

```
bash scripts/test_ade20k_ipose.sh
```

### Training

You can use the provided training scripts in [`scripts`](./scripts) to train your model. For example:

```
bash scripts/train_ade20k_ipose.sh
```

## Citation

---

```
@inproceedings{wei2023inferring,
title={Inferring and leveraging parts from object shape for improving semantic image synthesis},
author={Wei, Yuxiang and Ji, Zhilong and Wu, Xiaohe and Bai, Jinfeng and Zhang, Lei and Zuo, Wangmeng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={11248--11258},
year={2023}
}
```

## Acknowledgments

---

This code is built on [SPADE](https://github.com/NVlabs/SPADE) and [OASIS](https://github.com/boschresearch/OASIS). We thank the authors for sharing their code.
Binary file added assets/teaser.png
158 changes: 158 additions & 0 deletions config.py
@@ -0,0 +1,158 @@
import argparse
import pickle
import os
import utils.utils as utils
# from models.biggan.utils import prepare_parser

def read_arguments(train=True):
    parser = argparse.ArgumentParser()
    # parser = prepare_parser()
    parser = add_all_arguments(parser, train)
    parser.add_argument('--phase', type=str, default='train')
    opt = parser.parse_args()
    if train:
        set_dataset_default_lm(opt, parser)
        if opt.continue_train:
            update_options_from_file(opt, parser)
        opt = parser.parse_args()
    opt.phase = 'train' if train else 'test'
    if train:
        opt.loaded_latest_iter = 0 if not opt.continue_train else load_iter(opt)
    if opt.seed > -1:
        utils.fix_seed(opt.seed)
    print_options(opt, parser)
    if train:
        save_options(opt, parser)
    return opt
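
# Typical usage (hypothetical entry point -- the train/test scripts themselves
# are not part of this excerpt):
#   from config import read_arguments
#   opt = read_arguments(train=True)  # parses CLI flags, applies dataset
#                                     # defaults, and restores saved options
#                                     # when --continue_train is set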


def add_all_arguments(parser, train):
    # --- general options ---
    parser.add_argument('--name', type=str, default='label2coco', help='name of the experiment. It decides where to store samples and models')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, or 0,1,2. use -1 for CPU')
    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
    parser.add_argument('--no_spectral_norm', action='store_true', help='this option deactivates spectral norm in all layers')
    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/', help='path to dataset root')
    parser.add_argument('--dataset_mode', type=str, default='coco', help='this option indicates which dataset should be loaded')
    parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')

    # for generator
    parser.add_argument('--num_res_blocks', type=int, default=6, help='number of residual blocks in G and D')
    parser.add_argument('--channels_G', type=int, default=64, help='# of gen filters in first conv layer in generator')
    parser.add_argument('--param_free_norm', type=str, default='syncbatch', help='which norm to use in generator before SPADE')
    parser.add_argument('--spade_ks', type=int, default=3, help='kernel size of convs inside SPADE')
    parser.add_argument('--no_EMA', action='store_true', help='if specified, do *not* compute exponential moving averages')
    parser.add_argument('--EMA_decay', type=float, default=0.9999, help='decay in exponential moving averages')
    parser.add_argument('--no_3dnoise', action='store_true', default=False, help='if specified, do *not* concatenate noise to label maps')
    parser.add_argument('--z_dim', type=int, default=64, help="dimension of the latent z vector")
    parser.add_argument('--use_spd', action='store_true', default=False, help='if specified, use the SPD block')
    parser.add_argument('--use_clip', action='store_true', default=False, help='if specified, use CLIP features')
    parser.add_argument('--use_edge', action='store_true', default=False, help='if specified, use edge maps')
    parser.add_argument('--use_coord', action='store_true', default=False, help='if specified, use coordinate features')
    parser.add_argument('--no_sean', action='store_true', default=False, help='if specified, do *not* use the SEAN block')
    parser.add_argument('--use_globalD', action='store_true', default=False, help='if specified, use a global discriminator')
    parser.add_argument('--part_nc', type=int, default=0, help='# of part channels')
    parser.add_argument('--n_support', type=int, default=3, help='# of support shapes')
    parser.add_argument('--n_att_layers', type=int, default=3, help='# of attention layers')
    if train:
        parser.add_argument('--freq_print', type=int, default=1000, help='frequency of showing training results')
        parser.add_argument('--freq_save_ckpt', type=int, default=20000, help='frequency of saving the checkpoints')
        parser.add_argument('--freq_save_latest', type=int, default=10000, help='frequency of saving the latest model')
        parser.add_argument('--freq_smooth_loss', type=int, default=250, help='smoothing window for loss visualization')
        parser.add_argument('--freq_save_loss', type=int, default=2500, help='frequency of loss plot updates')
        parser.add_argument('--freq_fid', type=int, default=5000, help='frequency of saving the fid score (in training iterations)')
        parser.add_argument('--continue_train', action='store_true', help='resume previously interrupted training')
        parser.add_argument('--which_iter', type=str, default='latest', help='which epoch to load when continue_train')
        parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train')
        parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
        parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
        parser.add_argument('--lr_g', type=float, default=0.0001, help='G learning rate, default=0.0001')
        parser.add_argument('--lr_d', type=float, default=0.0004, help='D learning rate, default=0.0004')

        parser.add_argument('--channels_D', type=int, default=64, help='# of discrim filters in first conv layer in discriminator')
        parser.add_argument('--add_vgg_loss', action='store_true', help='if specified, add VGG feature matching loss')
        parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for VGG loss')
        parser.add_argument('--no_balancing_inloss', action='store_true', default=False, help='if specified, do *not* use class balancing in the loss function')
        parser.add_argument('--no_labelmix', action='store_true', default=False, help='if specified, do *not* use LabelMix')
        parser.add_argument('--lambda_labelmix', type=float, default=10.0, help='weight for LabelMix regularization')

        parser.add_argument('--results_dir', type=str, default='./results/', help='saves testing results here.')
        parser.add_argument('--ckpt_iter', type=str, default='latest', help='which epoch to load to evaluate a model')
        parser.add_argument('--test_batch', type=int, default=4, help='test batch size')
    else:
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves testing results here.')
        parser.add_argument('--ckpt_iter', type=str, default='best', help='which epoch to load to evaluate a model')
    return parser
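
# Example invocation (script name hypothetical; the flags are the ones defined
# above):
#   python train.py --name ade20k_ipose --dataset_mode ade20k \
#       --dataroot ./datasets/ADEChallengeData2016 --batch_size 16 --gpu_ids 0,1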


def set_dataset_default_lm(opt, parser):
    if 'ade20k' in opt.dataset_mode:
        parser.set_defaults(lambda_labelmix=10.0)
        parser.set_defaults(EMA_decay=0.9999)
    if 'cityscapes' in opt.dataset_mode:
        parser.set_defaults(lr_g=0.0004)
        parser.set_defaults(lambda_labelmix=5.0)
        parser.set_defaults(freq_fid=2500)
        parser.set_defaults(EMA_decay=0.999)
    if "coco" in opt.dataset_mode:
        parser.set_defaults(lambda_labelmix=10.0)
        parser.set_defaults(EMA_decay=0.9999)
        parser.set_defaults(num_epochs=100)
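
# Note: argparse's set_defaults only affects a later parse_args call, which is
# why read_arguments parses a second time after these overrides -- the
# dataset-specific defaults take effect while explicit CLI flags still win:
#   p = argparse.ArgumentParser()
#   p.add_argument('--lr_g', type=float, default=0.0001)
#   p.set_defaults(lr_g=0.0004)                # the cityscapes override
#   p.parse_args([]).lr_g                      # -> 0.0004 (new default)
#   p.parse_args(['--lr_g', '0.001']).lr_g     # -> 0.001 (CLI wins)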


def save_options(opt, parser):
    path_name = os.path.join(opt.checkpoints_dir, opt.name)
    os.makedirs(path_name, exist_ok=True)
    with open(path_name + '/opt.txt', 'wt') as opt_file:
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))

    with open(path_name + '/opt.pkl', 'wb') as opt_file:
        pickle.dump(opt, opt_file)


def update_options_from_file(opt, parser):
    new_opt = load_options(opt)
    for k, v in sorted(vars(opt).items()):
        if hasattr(new_opt, k) and v != getattr(new_opt, k):
            new_val = getattr(new_opt, k)
            parser.set_defaults(**{k: new_val})
    return parser


def load_options(opt):
    file_name = os.path.join(opt.checkpoints_dir, opt.name, "opt.pkl")
    new_opt = pickle.load(open(file_name, 'rb'))
    return new_opt


def load_iter(opt):
    if opt.which_iter == "latest":
        with open(os.path.join(opt.checkpoints_dir, opt.name, "latest_iter.txt"), "r") as f:
            res = int(f.read())
        return res
    elif opt.which_iter == "best":
        with open(os.path.join(opt.checkpoints_dir, opt.name, "best_iter.txt"), "r") as f:
            res = int(f.read())
        return res
    else:
        return int(opt.which_iter)
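
# load_iter expects the training loop to have written the current iteration to
# latest_iter.txt / best_iter.txt; a sketch of the counterpart writer (not
# part of this file):
#   def save_iter(opt, cur_iter, best=False):
#       fname = "best_iter.txt" if best else "latest_iter.txt"
#       with open(os.path.join(opt.checkpoints_dir, opt.name, fname), "w") as f:
#           f.write(str(cur_iter))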


def print_options(opt, parser):
    message = ''
    message += '----------------- Options ---------------\n'
    for k, v in sorted(vars(opt).items()):
        comment = ''
        default = parser.get_default(k)
        if v != default:
            comment = '\t[default: %s]' % str(default)
        message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
    message += '----------------- End -------------------'
    print(message)
73 changes: 73 additions & 0 deletions dataloaders/Ade20kDataset.py
@@ -0,0 +1,73 @@
import random
import torch
from torchvision import transforms as TR
import os
from PIL import Image
import numpy as np


class Ade20kDataset(torch.utils.data.Dataset):
    def __init__(self, opt, for_metrics):
        if opt.phase == "test" or for_metrics:
            opt.load_size = 256
        else:
            opt.load_size = 286
        opt.crop_size = 256
        opt.label_nc = 150
        opt.contain_dontcare_label = True
        opt.semantic_nc = 151  # label_nc + unknown
        opt.cache_filelist_read = False
        opt.cache_filelist_write = False
        opt.aspect_ratio = 1.0

        self.opt = opt
        self.for_metrics = for_metrics
        self.images, self.labels, self.paths = self.list_images()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.paths[0], self.images[idx])).convert('RGB')
        label = Image.open(os.path.join(self.paths[1], self.labels[idx]))
        image, label = self.transforms(image, label)
        # to_tensor rescales the label map to [0, 1]; multiply by 255 to
        # recover the integer class ids
        label = label * 255
        return {"image": image, "label": label, "name": self.images[idx]}

    def list_images(self):
        mode = "validation" if self.opt.phase == "test" or self.for_metrics else "training"
        path_img = os.path.join(self.opt.dataroot, "images", mode)
        path_lab = os.path.join(self.opt.dataroot, "annotations", mode)
        img_list = os.listdir(path_img)
        lab_list = os.listdir(path_lab)
        img_list = [filename for filename in img_list if ".png" in filename or ".jpg" in filename]
        lab_list = [filename for filename in lab_list if ".png" in filename or ".jpg" in filename]
        images = sorted(img_list)
        labels = sorted(lab_list)
        assert len(images) == len(labels), "different len of images and labels %s - %s" % (len(images), len(labels))
        for i in range(len(images)):
            assert os.path.splitext(images[i])[0] == os.path.splitext(labels[i])[0], '%s and %s are not matching' % (images[i], labels[i])
        return images, labels, (path_img, path_lab)

    def transforms(self, image, label):
        assert image.size == label.size
        # resize (bicubic for the image, nearest for the label map)
        new_width, new_height = (self.opt.load_size, self.opt.load_size)
        image = TR.functional.resize(image, (new_width, new_height), Image.BICUBIC)
        label = TR.functional.resize(label, (new_width, new_height), Image.NEAREST)
        # random crop to crop_size
        crop_x = random.randint(0, np.maximum(0, new_width - self.opt.crop_size))
        crop_y = random.randint(0, np.maximum(0, new_height - self.opt.crop_size))
        image = image.crop((crop_x, crop_y, crop_x + self.opt.crop_size, crop_y + self.opt.crop_size))
        label = label.crop((crop_x, crop_y, crop_x + self.opt.crop_size, crop_y + self.opt.crop_size))
        # random horizontal flip (training only)
        if not (self.opt.phase == "test" or self.opt.no_flip or self.for_metrics):
            if random.random() < 0.5:
                image = TR.functional.hflip(image)
                label = TR.functional.hflip(label)
        # to tensor
        image = TR.functional.to_tensor(image)
        label = TR.functional.to_tensor(label)
        # normalize the image to [-1, 1]
        image = TR.functional.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return image, label
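
# Minimal usage sketch (hypothetical; `opt` is the namespace returned by
# config.read_arguments()):
#   from torch.utils.data import DataLoader
#   dataset = Ade20kDataset(opt, for_metrics=False)
#   loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)
#   batch = next(iter(loader))  # {"image": ..., "label": ..., "name": ...}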