-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
65 changed files
with
9,363 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
_debug* | ||
.env | ||
__pycache__ | ||
_sc.py | ||
*.ckpt | ||
*.bin | ||
|
||
checkpoints | ||
.idea | ||
.idea/workspace.xml | ||
.DS_Store | ||
*/__pycache__git | ||
.pyc | ||
.iml | ||
__pycache__/ | ||
*/__pycache__/ | ||
*/*/__pycache__/ | ||
*/*/*/__pycache__/ | ||
*/*/*/*/__pycache__/ | ||
*/*/*/*/*/__pycache__/ | ||
*/*/*/*/*/*/__pycache__/ |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,100 @@ | ||
# iPOSE | ||
Official pytorch implementation of "Inferring and Leveraging Parts from Object Shape for Improving Semantic Image Synthesis" (CVPR2023) | ||
# [Inferring and Leveraging Parts from Object Shape for Improving Semantic Image Synthesis](https://arxiv.org/abs/2305.19547) | ||
|
||
Code will be coming soon. | ||
--- | ||
|
||
|
||
## Method Details | ||
|
||
--- | ||
|
||
![teaser](assets/teaser.png) | ||
|
||
We propose a method iPOSE to infer parts from object shape and leverage them to improve semantic image synthesis. It can generate more photo-realistic parts from the given semantic map, while having the flexibility to control the generated objects | ||
|
||
## Quick Start | ||
|
||
--- | ||
|
||
### Environment Setup | ||
|
||
``` | ||
git clone https://github.com/csyxwei/iPOSE.git | ||
cd iPOSE | ||
conda create -n ipose python=3.7.6 | ||
conda activate ipose | ||
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Preparing Dataset | ||
|
||
Please follow the instructions to prepare COCO-Stuff, Cityscapes or ADE20K datasets as outlined in [SPADE](https://github.com/NVlabs/SPADE). Besides, we get the instance maps of ADE20K from [instancesegmentation](http://sceneparsing.csail.mit.edu/data/ChallengeData2017/annotations_instance.tar). | ||
|
||
The part dataset we used can be downloaded from [GoogleDrive](https://drive.google.com/file/d/1vZFrXQg1TnhMJh8c_g7o8bXJg5oV9nLC/view?usp=sharing). | ||
|
||
For COCO, we employ [./utils/coco_util/create_ins_dict.py](./utils/coco_util/create_ins_dict.py) to preprocess the instantce parameters for fast training. | ||
|
||
The final data structure is like this: | ||
|
||
``` | ||
datasets | ||
├── Cityscapes | ||
│ ├── leftImg8bit | ||
│ ├── gtFine | ||
│ ├── ... | ||
├── ADEChallengeData2016 | ||
│ ├── images | ||
│ ├── annotations | ||
│ ├── annotations_instance | ||
│ ├── ... | ||
├── COCO | ||
│ ├── train_img | ||
│ ├── train_label | ||
│ ├── train_inst | ||
│ ├── val_img | ||
│ ├── val_label | ||
│ ├── val_inst | ||
│ ├── train_dict.npy | ||
│ ├── val_dict.npy | ||
│ ├── ... | ||
├── Full_Parts | ||
│ ├── ... | ||
``` | ||
|
||
### Testing | ||
|
||
Downloading pretrained models from [GoogleDrive](https://drive.google.com/drive/folders/1Vz5j6PaLl_tPDacGTdJSVyiA8UH74Ftp?usp=sharing) and save them under `./checkpoints`. | ||
|
||
After that, you can use the provided testing scripts in [`scripts`](./scripts) for testing and evaluation. For example, | ||
|
||
``` | ||
bash scripts/test_ade20k_ipose.sh | ||
``` | ||
|
||
### Training | ||
|
||
You can use the provided training scripts in [`scripts`](./scripts) to train your model. For example, | ||
|
||
``` | ||
bash scripts/train_ade20k_ipose.sh | ||
``` | ||
|
||
## Citation | ||
|
||
--- | ||
|
||
``` | ||
@inproceedings{wei2023inferring, | ||
title={Inferring and leveraging parts from object shape for improving semantic image synthesis}, | ||
author={Wei, Yuxiang and Ji, Zhilong and Wu, Xiaohe and Bai, Jinfeng and Zhang, Lei and Zuo, Wangmeng}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
pages={11248--11258}, | ||
year={2023} | ||
} | ||
``` | ||
|
||
## Acknowledgments | ||
|
||
--- | ||
|
||
This code is built on [SPADE](https://github.com/NVlabs/SPADE) and [OASIS](https://github.com/boschresearch/OASIS). We thank the authors for sharing the codes. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
import argparse | ||
import pickle | ||
import os | ||
import utils.utils as utils | ||
# from models.biggan.utils import prepare_parser | ||
|
||
def read_arguments(train=True): | ||
parser = argparse.ArgumentParser() | ||
# parser = prepare_parser() | ||
parser = add_all_arguments(parser, train) | ||
parser.add_argument('--phase', type=str, default='train') | ||
opt = parser.parse_args() | ||
if train: | ||
set_dataset_default_lm(opt, parser) | ||
if opt.continue_train: | ||
update_options_from_file(opt, parser) | ||
opt = parser.parse_args() | ||
opt.phase = 'train' if train else 'test' | ||
if train: | ||
opt.loaded_latest_iter = 0 if not opt.continue_train else load_iter(opt) | ||
if opt.seed > -1: | ||
utils.fix_seed(opt.seed) | ||
print_options(opt, parser) | ||
if train: | ||
save_options(opt, parser) | ||
return opt | ||
|
||
|
||
def add_all_arguments(parser, train): | ||
#--- general options --- | ||
parser.add_argument('--name', type=str, default='label2coco', help='name of the experiment. It decides where to store samples and models') | ||
parser.add_argument('--seed', type=int, default=42, help='random seed') | ||
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') | ||
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') | ||
parser.add_argument('--no_spectral_norm', action='store_true', help='this option deactivates spectral norm in all layers') | ||
parser.add_argument('--batch_size', type=int, default=1, help='input batch size') | ||
parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/', help='path to dataset root') | ||
parser.add_argument('--dataset_mode', type=str, default='coco', help='this option indicates which dataset should be loaded') | ||
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') | ||
|
||
# for generator | ||
parser.add_argument('--num_res_blocks', type=int, default=6, help='number of residual blocks in G and D') | ||
parser.add_argument('--channels_G', type=int, default=64, help='# of gen filters in first conv layer in generator') | ||
parser.add_argument('--param_free_norm', type=str, default='syncbatch', help='which norm to use in generator before SPADE') | ||
parser.add_argument('--spade_ks', type=int, default=3, help='kernel size of convs inside SPADE') | ||
parser.add_argument('--no_EMA', action='store_true', help='if specified, do *not* compute exponential moving averages') | ||
parser.add_argument('--EMA_decay', type=float, default=0.9999, help='decay in exponential moving averages') | ||
parser.add_argument('--no_3dnoise', action='store_true', default=False, help='if specified, do *not* concatenate noise to label maps') | ||
parser.add_argument('--z_dim', type=int, default=64, help="dimension of the latent z vector") | ||
parser.add_argument('--use_spd', action='store_true', default=False, help='if specified, do *not* use sean block') | ||
parser.add_argument('--use_clip', action='store_true', default=False, help='if specified, do *not* use sean block') | ||
parser.add_argument('--use_edge', action='store_true', default=False, help='if specified, do *not* use sean block') | ||
parser.add_argument('--use_coord', action='store_true', default=False, help='if specified, do *not* use sean block') | ||
parser.add_argument('--no_sean', action='store_true', default=False, help='if specified, do *not* use sean block') | ||
parser.add_argument('--use_globalD', action='store_true', default=False, help='if specified, do *not* use sean block') | ||
parser.add_argument('--part_nc', type=int, default=0, help='# of part') | ||
parser.add_argument('--n_support', type=int, default=3, help='# of support') | ||
parser.add_argument('--n_att_layers', type=int, default=3, help='# of support') | ||
if train: | ||
parser.add_argument('--freq_print', type=int, default=1000, help='frequency of showing training results') | ||
parser.add_argument('--freq_save_ckpt', type=int, default=20000, help='frequency of saving the checkpoints') | ||
parser.add_argument('--freq_save_latest', type=int, default=10000, help='frequency of saving the latest model') | ||
parser.add_argument('--freq_smooth_loss', type=int, default=250, help='smoothing window for loss visualization') | ||
parser.add_argument('--freq_save_loss', type=int, default=2500, help='frequency of loss plot updates') | ||
parser.add_argument('--freq_fid', type=int, default=5000, help='frequency of saving the fid score (in training iterations)') | ||
parser.add_argument('--continue_train', action='store_true', help='resume previously interrupted training') | ||
parser.add_argument('--which_iter', type=str, default='latest', help='which epoch to load when continue_train') | ||
parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train') | ||
parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam') | ||
parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') | ||
parser.add_argument('--lr_g', type=float, default=0.0001, help='G learning rate, default=0.0001') | ||
parser.add_argument('--lr_d', type=float, default=0.0004, help='D learning rate, default=0.0004') | ||
|
||
parser.add_argument('--channels_D', type=int, default=64, help='# of discrim filters in first conv layer in discriminator') | ||
parser.add_argument('--add_vgg_loss', action='store_true', help='if specified, add VGG feature matching loss') | ||
parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for VGG loss') | ||
parser.add_argument('--no_balancing_inloss', action='store_true', default=False, help='if specified, do *not* use class balancing in the loss function') | ||
parser.add_argument('--no_labelmix', action='store_true', default=False, help='if specified, do *not* use LabelMix') | ||
parser.add_argument('--lambda_labelmix', type=float, default=10.0, help='weight for LabelMix regularization') | ||
|
||
parser.add_argument('--results_dir', type=str, default='./results/', help='saves testing results here.') | ||
parser.add_argument('--ckpt_iter', type=str, default='latest', help='which epoch to load to evaluate a model') | ||
parser.add_argument('--test_batch', type=int, default=4, help='test batch') | ||
else: | ||
parser.add_argument('--results_dir', type=str, default='./results/', help='saves testing results here.') | ||
parser.add_argument('--ckpt_iter', type=str, default='best', help='which epoch to load to evaluate a model') | ||
return parser | ||
|
||
|
||
def set_dataset_default_lm(opt, parser): | ||
if 'ade20k' in opt.dataset_mode: | ||
parser.set_defaults(lambda_labelmix=10.0) | ||
parser.set_defaults(EMA_decay=0.9999) | ||
if 'cityscapes' in opt.dataset_mode: | ||
parser.set_defaults(lr_g=0.0004) | ||
parser.set_defaults(lambda_labelmix=5.0) | ||
parser.set_defaults(freq_fid=2500) | ||
parser.set_defaults(EMA_decay=0.999) | ||
if "coco" in opt.dataset_mode: | ||
parser.set_defaults(lambda_labelmix=10.0) | ||
parser.set_defaults(EMA_decay=0.9999) | ||
parser.set_defaults(num_epochs=100) | ||
|
||
|
||
def save_options(opt, parser): | ||
path_name = os.path.join(opt.checkpoints_dir,opt.name) | ||
os.makedirs(path_name, exist_ok=True) | ||
with open(path_name + '/opt.txt', 'wt') as opt_file: | ||
for k, v in sorted(vars(opt).items()): | ||
comment = '' | ||
default = parser.get_default(k) | ||
if v != default: | ||
comment = '\t[default: %s]' % str(default) | ||
opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)) | ||
|
||
with open(path_name + '/opt.pkl', 'wb') as opt_file: | ||
pickle.dump(opt, opt_file) | ||
|
||
|
||
def update_options_from_file(opt, parser): | ||
new_opt = load_options(opt) | ||
for k, v in sorted(vars(opt).items()): | ||
if hasattr(new_opt, k) and v != getattr(new_opt, k): | ||
new_val = getattr(new_opt, k) | ||
parser.set_defaults(**{k: new_val}) | ||
return parser | ||
|
||
|
||
def load_options(opt): | ||
file_name = os.path.join(opt.checkpoints_dir, opt.name, "opt.pkl") | ||
new_opt = pickle.load(open(file_name, 'rb')) | ||
return new_opt | ||
|
||
|
||
def load_iter(opt): | ||
if opt.which_iter == "latest": | ||
with open(os.path.join(opt.checkpoints_dir, opt.name, "latest_iter.txt"), "r") as f: | ||
res = int(f.read()) | ||
return res | ||
elif opt.which_iter == "best": | ||
with open(os.path.join(opt.checkpoints_dir, opt.name, "best_iter.txt"), "r") as f: | ||
res = int(f.read()) | ||
return res | ||
else: | ||
return int(opt.which_iter) | ||
|
||
|
||
def print_options(opt, parser): | ||
message = '' | ||
message += '----------------- Options ---------------\n' | ||
for k, v in sorted(vars(opt).items()): | ||
comment = '' | ||
default = parser.get_default(k) | ||
if v != default: | ||
comment = '\t[default: %s]' % str(default) | ||
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) | ||
message += '----------------- End -------------------' | ||
print(message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import random | ||
import torch | ||
from torchvision import transforms as TR | ||
import os | ||
from PIL import Image | ||
import numpy as np | ||
|
||
|
||
class Ade20kDataset(torch.utils.data.Dataset): | ||
def __init__(self, opt, for_metrics): | ||
if opt.phase == "test" or for_metrics: | ||
opt.load_size = 256 | ||
else: | ||
opt.load_size = 286 | ||
opt.crop_size = 256 | ||
opt.label_nc = 150 | ||
opt.contain_dontcare_label = True | ||
opt.semantic_nc = 151 # label_nc + unknown | ||
opt.cache_filelist_read = False | ||
opt.cache_filelist_write = False | ||
opt.aspect_ratio = 1.0 | ||
|
||
self.opt = opt | ||
self.for_metrics = for_metrics | ||
self.images, self.labels, self.paths = self.list_images() | ||
|
||
def __len__(self,): | ||
return len(self.images) | ||
|
||
def __getitem__(self, idx): | ||
image = Image.open(os.path.join(self.paths[0], self.images[idx])).convert('RGB') | ||
label = Image.open(os.path.join(self.paths[1], self.labels[idx])) | ||
image, label = self.transforms(image, label) | ||
label = label * 255 | ||
return {"image": image, "label": label, "name": self.images[idx]} | ||
|
||
def list_images(self): | ||
mode = "validation" if self.opt.phase == "test" or self.for_metrics else "training" | ||
path_img = os.path.join(self.opt.dataroot, "images", mode) | ||
path_lab = os.path.join(self.opt.dataroot, "annotations", mode) | ||
img_list = os.listdir(path_img) | ||
lab_list = os.listdir(path_lab) | ||
img_list = [filename for filename in img_list if ".png" in filename or ".jpg" in filename] | ||
lab_list = [filename for filename in lab_list if ".png" in filename or ".jpg" in filename] | ||
images = sorted(img_list) | ||
labels = sorted(lab_list) | ||
assert len(images) == len(labels), "different len of images and labels %s - %s" % (len(images), len(labels)) | ||
for i in range(len(images)): | ||
assert os.path.splitext(images[i])[0] == os.path.splitext(labels[i])[0], '%s and %s are not matching' % (images[i], labels[i]) | ||
return images, labels, (path_img, path_lab) | ||
|
||
def transforms(self, image, label): | ||
assert image.size == label.size | ||
# resize | ||
new_width, new_height = (self.opt.load_size, self.opt.load_size) | ||
image = TR.functional.resize(image, (new_width, new_height), Image.BICUBIC) | ||
label = TR.functional.resize(label, (new_width, new_height), Image.NEAREST) | ||
# crop | ||
crop_x = random.randint(0, np.maximum(0, new_width - self.opt.crop_size)) | ||
crop_y = random.randint(0, np.maximum(0, new_height - self.opt.crop_size)) | ||
image = image.crop((crop_x, crop_y, crop_x + self.opt.crop_size, crop_y + self.opt.crop_size)) | ||
label = label.crop((crop_x, crop_y, crop_x + self.opt.crop_size, crop_y + self.opt.crop_size)) | ||
# flip | ||
if not (self.opt.phase == "test" or self.opt.no_flip or self.for_metrics): | ||
if random.random() < 0.5: | ||
image = TR.functional.hflip(image) | ||
label = TR.functional.hflip(label) | ||
# to tensor | ||
image = TR.functional.to_tensor(image) | ||
label = TR.functional.to_tensor(label) | ||
# normalize | ||
image = TR.functional.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
return image, label |
Oops, something went wrong.