Merge branch 'idx_more_setting' into 'master'
Update model URLs and add three more settings

See merge request open-mmlab/mmediting!349
wangxintao committed Jul 8, 2020
2 parents 3db5309 + 6b0766d commit 7e07b08
Showing 8 changed files with 405 additions and 7 deletions.
6 changes: 3 additions & 3 deletions configs/mattors/dim/README.md
@@ -18,9 +18,9 @@
|:----------:|:----:|:-----:|:----:|:----:|:--------:|
| stage1 (paper) | 54.6 | 0.017 | 36.7 | 55.3 | - |
| stage3 (paper) | **50.4** | **0.014** | 31.0 | 50.8 | - |
| stage1 (our) | 53.8 | 0.017 | 32.7 | 54.5 | [model](TODO) \| [log](TODO) |
| stage2 (our) | 52.3 | 0.016 | 29.4 | 52.4 | [model](TODO) \| [log](TODO) |
| stage3 (our) | 50.6 | 0.015 | **29.0** | **50.7** | [model](TODO) \| [log](TODO) |
| stage1 (our) | 53.8 | 0.017 | 32.7 | 54.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/dim/dim_stage1_v16_1x1_1000k_comp1k_SAD-53.8_20200605_140257-979a420f.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/dim/dim_stage1_v16_1x1_1000k_comp1k_20200605_140257.log.json) |
| stage2 (our) | 52.3 | 0.016 | 29.4 | 52.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/dim/dim_stage2_v16_pln_1x1_1000k_comp1k_SAD-52.3_20200607_171909-d83c4775.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/dim/dim_stage2_v16_pln_1x1_1000k_comp1k_20200607_171909.log.json) |
| stage3 (our) | 50.6 | 0.015 | **29.0** | **50.7** | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/dim/dim_stage3_v16_pln_1x1_1000k_comp1k_SAD-50.6_20200609_111851-647f24b6.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/dim/dim_stage3_v16_pln_1x1_1000k_comp1k_20200609_111851.log.json) |

**NOTE**

11 changes: 9 additions & 2 deletions configs/mattors/gca/README.md
@@ -17,5 +17,12 @@
|:----------:|:-----:|:------:|:-----:|:-----:|:--------:|
| baseline (paper) | 40.62 | 0.0106 | 21.53 | 38.43 | - |
| GCA (paper) | 35.28 | 0.0091 | 16.92 | 32.53 | - |
| baseline (our) | 36.50 | 0.0090 | 17.40 | 34.33 | [model](TODO) \| [log](TODO) |
| GCA (our) | **34.77** | **0.0080** | **16.33** | **32.20** | [model](TODO) \| [log](TODO) |
| baseline (our) | 36.50 | 0.0090 | 17.40 | 34.33 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/baseline_r34_4x10_200k_comp1k_SAD-36.50_20200614_105701-95be1750.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/baseline_r34_4x10_200k_comp1k_20200614_105701.log.json) |
| GCA (our) | **34.77** | **0.0080** | **16.33** | **32.20** | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/gca_r34_4x10_200k_comp1k_SAD-34.77_20200604_213848-4369bea0.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/gca_r34_4x10_200k_comp1k_20200604_213848.log.json) |

### More results

| Method | SAD | MSE | GRAD | CONN | Download |
|:----------:|:-----:|:------:|:-----:|:-----:|:--------:|
| baseline (with DIM pipeline) | 49.95 | 0.0144 | 30.21 | 49.67 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/TODO_to_be_added) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/TODO_to_be_added) |
| GCA (with DIM pipeline) | 49.42 | 0.0129 | 28.07 | 49.47 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/TODO_to_be_added) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/gca/TODO_to_be_added) |
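
The released checkpoint links in the tables above appear to follow a common file-name pattern: config name, an optional `SAD-<value>` tag, a date and time stamp, and an 8-character hash. The sketch below parses that pattern. The regular expression and the helper name `parse_checkpoint_name` are assumptions inferred from the URLs in this commit, not a documented MMEditing API; placeholder links such as `TODO_to_be_added` are rejected.

```python
import re

# Pattern inferred from the download links in the README tables.
# It is an assumption for illustration, not an official naming spec.
_CKPT_RE = re.compile(
    r'^(?P<config>.+?)'                    # e.g. gca_r34_4x10_200k_comp1k
    r'(?:_SAD-(?P<sad>\d+(?:\.\d+)?))?'    # optional metric tag, e.g. SAD-34.77
    r'_(?P<date>\d{8})_(?P<time>\d{6})'    # timestamp, e.g. 20200604_213848
    r'-(?P<hash>[0-9a-f]{8})\.pth$')       # short hash and extension


def parse_checkpoint_name(url_or_name: str) -> dict:
    """Split a checkpoint URL or file name into its labelled parts."""
    name = url_or_name.rsplit('/', 1)[-1]  # keep only the file name
    match = _CKPT_RE.match(name)
    if match is None:
        raise ValueError(f'unrecognised checkpoint name: {name}')
    info = match.groupdict()
    # Convert the SAD tag to a float when it is present.
    info['sad'] = float(info['sad']) if info['sad'] is not None else None
    return info


if __name__ == '__main__':
    url = ('https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/'
           'v0.1/mattors/gca/'
           'gca_r34_4x10_200k_comp1k_SAD-34.77_20200604_213848-4369bea0.pth')
    print(parse_checkpoint_name(url))
```

A check like this can flag table rows whose links are still placeholders before the README is published.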
125 changes: 125 additions & 0 deletions configs/mattors/gca/baseline_dimaug_r34_4x10_200k_comp1k.py
@@ -0,0 +1,125 @@
# model settings
model = dict(
    type='GCA',
    backbone=dict(
        type='SimpleEncoderDecoder',
        encoder=dict(
            type='ResShortcutEnc',
            block='BasicBlock',
            layers=[3, 4, 4, 2],
            in_channels=4,
            with_spectral_norm=True),
        decoder=dict(
            type='ResShortcutDec',
            block='BasicBlockDec',
            layers=[2, 3, 3, 2],
            with_spectral_norm=True)),
    loss_alpha=dict(type='L1Loss'),
    pretrained='open-mmlab://mmedit/res34_en_nomixup')
train_cfg = dict(train_backbone=True)
test_cfg = dict(metrics=['SAD', 'MSE', 'GRAD', 'CONN'])

# dataset settings
dataset_type = 'AdobeComp1kDataset'
data_root = './data/adobe_composition-1k/'
bg_dir = './data/coco/train2017'
img_norm_cfg = dict(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile', key='alpha', flag='grayscale'),
    dict(type='LoadImageFromFile', key='merged'),
    dict(
        type='CropAroundUnknown',
        keys=['alpha', 'merged'],
        crop_sizes=[320, 480, 640]),
    dict(type='Flip', keys=['alpha', 'merged']),
    dict(
        type='Resize',
        keys=['alpha', 'merged'],
        scale=(320, 320),
        keep_ratio=False),
    dict(type='GenerateTrimap', kernel_size=(1, 30)),
    dict(type='RescaleToZeroOne', keys=['merged', 'alpha']),
    dict(type='Normalize', keys=['merged'], **img_norm_cfg),
    dict(type='Collect', keys=['merged', 'alpha', 'trimap'], meta_keys=[]),
    dict(type='ImageToTensor', keys=['merged', 'alpha', 'trimap']),
    dict(type='FormatTrimap', to_onehot=False),
]
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='alpha',
        flag='grayscale',
        save_original_img=True),
    dict(
        type='LoadImageFromFile',
        key='trimap',
        flag='grayscale',
        save_original_img=True),
    dict(type='LoadImageFromFile', key='merged'),
    dict(type='Pad', keys=['trimap', 'merged'], mode='reflect'),
    dict(type='RescaleToZeroOne', keys=['merged']),
    dict(type='Normalize', keys=['merged'], **img_norm_cfg),
    dict(
        type='Collect',
        keys=['merged', 'trimap'],
        meta_keys=[
            'merged_path', 'pad', 'merged_ori_shape', 'ori_alpha', 'ori_trimap'
        ]),
    dict(type='ImageToTensor', keys=['merged', 'trimap']),
    dict(type='FormatTrimap', to_onehot=False),
]
data = dict(
    samples_per_gpu=10,
    workers_per_gpu=4,
    val_samples_per_gpu=1,
    val_workers_per_gpu=4,
    drop_last=True,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'training_list.json',
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'test_list.json',
        data_prefix=data_root,
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'test_list.json',
        data_prefix=data_root,
        pipeline=test_pipeline))

# optimizer
optimizers = dict(type='Adam', lr=4e-4, betas=[0.5, 0.999])
# learning policy
lr_config = dict(
    policy='CosineAnealing',
    min_lr=0,
    by_epoch=False,
    warmup='linear',
    warmup_iters=5000,
    warmup_ratio=0.001)

# checkpoint saving
checkpoint_config = dict(interval=2000, by_epoch=False)
evaluation = dict(interval=2000, save_image=False, gpu_collect=False)
# yapf:disable
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook'),
        # dict(type='PaviLoggerHook', init_kwargs=dict(project='gca'))
    ])
# yapf:enable

# runtime settings
total_iters = 200000
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/shortcut'
load_from = None
resume_from = None
workflow = [('train', 1)]
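
The `optimizers` and `lr_config` entries in the config above describe a cosine-annealed learning rate with a linear warmup over the first 5000 iterations. The sketch below evaluates such a schedule in plain Python using the values from the config. It is an approximation written for illustration: the exact behaviour of the training framework's hook may differ, and the function names `cosine_lr` and `lr_at` are hypothetical.

```python
import math

# Values copied from the config above.
BASE_LR = 4e-4          # optimizers['lr']
MIN_LR = 0.0            # lr_config['min_lr']
TOTAL_ITERS = 200_000   # total_iters
WARMUP_ITERS = 5_000    # lr_config['warmup_iters']
WARMUP_RATIO = 0.001    # lr_config['warmup_ratio']


def cosine_lr(cur_iter: int) -> float:
    """Cosine annealing from BASE_LR down to MIN_LR over TOTAL_ITERS."""
    progress = cur_iter / TOTAL_ITERS
    return MIN_LR + 0.5 * (BASE_LR - MIN_LR) * (1 + math.cos(math.pi * progress))


def lr_at(cur_iter: int) -> float:
    """Learning rate at an iteration, with linear warmup applied on top.

    Assumption: during warmup the cosine value is scaled by a factor that
    grows linearly from WARMUP_RATIO to 1.
    """
    lr = cosine_lr(cur_iter)
    if cur_iter < WARMUP_ITERS:
        scale = WARMUP_RATIO + (1 - WARMUP_RATIO) * cur_iter / WARMUP_ITERS
        lr *= scale
    return lr


if __name__ == '__main__':
    for it in (0, 2_500, 5_000, 100_000, 200_000):
        print(f'iter {it:>7}: lr = {lr_at(it):.3e}')
```

Under these assumptions the rate starts at `BASE_LR * WARMUP_RATIO`, reaches the cosine curve at iteration 5000, and decays to `MIN_LR` at the final iteration.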
125 changes: 125 additions & 0 deletions configs/mattors/gca/gca_dimaug_r34_4x10_200k_comp1k.py
@@ -0,0 +1,125 @@
# model settings
model = dict(
    type='GCA',
    backbone=dict(
        type='SimpleEncoderDecoder',
        encoder=dict(
            type='ResGCAEncoder',
            block='BasicBlock',
            layers=[3, 4, 4, 2],
            in_channels=4,
            with_spectral_norm=True),
        decoder=dict(
            type='ResGCADecoder',
            block='BasicBlockDec',
            layers=[2, 3, 3, 2],
            with_spectral_norm=True)),
    loss_alpha=dict(type='L1Loss'),
    pretrained='open-mmlab://mmedit/res34_en_nomixup')
train_cfg = dict(train_backbone=True)
test_cfg = dict(metrics=['SAD', 'MSE', 'GRAD', 'CONN'])

# dataset settings
dataset_type = 'AdobeComp1kDataset'
data_root = './data/adobe_composition-1k/'
bg_dir = './data/coco/train2017'
img_norm_cfg = dict(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile', key='alpha', flag='grayscale'),
    dict(type='LoadImageFromFile', key='merged'),
    dict(
        type='CropAroundUnknown',
        keys=['alpha', 'merged'],
        crop_sizes=[320, 480, 640]),
    dict(type='Flip', keys=['alpha', 'merged']),
    dict(
        type='Resize',
        keys=['alpha', 'merged'],
        scale=(320, 320),
        keep_ratio=False),
    dict(type='GenerateTrimap', kernel_size=(1, 30)),
    dict(type='RescaleToZeroOne', keys=['merged', 'alpha']),
    dict(type='Normalize', keys=['merged'], **img_norm_cfg),
    dict(type='Collect', keys=['merged', 'alpha', 'trimap'], meta_keys=[]),
    dict(type='ImageToTensor', keys=['merged', 'alpha', 'trimap']),
    dict(type='FormatTrimap', to_onehot=False),
]
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='alpha',
        flag='grayscale',
        save_original_img=True),
    dict(
        type='LoadImageFromFile',
        key='trimap',
        flag='grayscale',
        save_original_img=True),
    dict(type='LoadImageFromFile', key='merged'),
    dict(type='Pad', keys=['trimap', 'merged'], mode='reflect'),
    dict(type='RescaleToZeroOne', keys=['merged']),
    dict(type='Normalize', keys=['merged'], **img_norm_cfg),
    dict(
        type='Collect',
        keys=['merged', 'trimap'],
        meta_keys=[
            'merged_path', 'pad', 'merged_ori_shape', 'ori_alpha', 'ori_trimap'
        ]),
    dict(type='ImageToTensor', keys=['merged', 'trimap']),
    dict(type='FormatTrimap', to_onehot=False),
]
data = dict(
    samples_per_gpu=10,
    workers_per_gpu=4,
    val_samples_per_gpu=1,
    val_workers_per_gpu=4,
    drop_last=True,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'training_list.json',
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'test_list.json',
        data_prefix=data_root,
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'test_list.json',
        data_prefix=data_root,
        pipeline=test_pipeline))

# optimizer
optimizers = dict(type='Adam', lr=4e-4, betas=[0.5, 0.999])
# learning policy
lr_config = dict(
    policy='CosineAnealing',
    min_lr=0,
    by_epoch=False,
    warmup='linear',
    warmup_iters=5000,
    warmup_ratio=0.001)

# checkpoint saving
checkpoint_config = dict(interval=2000, by_epoch=False)
evaluation = dict(interval=2000, save_image=False, gpu_collect=False)
# yapf:disable
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook'),
        # dict(type='PaviLoggerHook', init_kwargs=dict(project='gca'))
    ])
# yapf:enable

# runtime settings
total_iters = 200000
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/gca'
load_from = None
resume_from = None
workflow = [('train', 1)]
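
The runtime settings above, together with `samples_per_gpu`, fix the overall training budget. The sketch below works out the arithmetic. It assumes the `4x10` in the config file name means 4 GPUs with 10 samples each; the GPU count does not appear in the config itself, so treat it as an assumption. It also assumes one checkpoint and one evaluation per full interval. The helper name `training_budget` is hypothetical.

```python
NUM_GPUS = 4                  # assumption: the "4" in "4x10" of the file name
SAMPLES_PER_GPU = 10          # data['samples_per_gpu']
TOTAL_ITERS = 200_000         # total_iters
CHECKPOINT_INTERVAL = 2_000   # checkpoint_config['interval']
EVAL_INTERVAL = 2_000         # evaluation['interval']


def training_budget(num_gpus: int = NUM_GPUS,
                    samples_per_gpu: int = SAMPLES_PER_GPU,
                    total_iters: int = TOTAL_ITERS) -> dict:
    """Summarise the batch size and counts implied by the settings."""
    effective_batch = num_gpus * samples_per_gpu
    return {
        'effective_batch_size': effective_batch,
        'samples_seen': effective_batch * total_iters,
        # One save / evaluation per completed interval (assumption).
        'num_checkpoints': total_iters // CHECKPOINT_INTERVAL,
        'num_evaluations': total_iters // EVAL_INTERVAL,
    }


if __name__ == '__main__':
    for key, value in training_budget().items():
        print(f'{key}: {value}')
```

Running with a different GPU count changes the effective batch size, so the learning rate may need retuning to match the reported results.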
2 changes: 1 addition & 1 deletion configs/mattors/gca/gca_r34_4x10_200k_comp1k.py
@@ -15,7 +15,7 @@
            layers=[2, 3, 3, 2],
            with_spectral_norm=True)),
    loss_alpha=dict(type='L1Loss'),
    pretrained='./weights/model_best_resnet34_En_nomixup_mmedit.pth')
    pretrained='open-mmlab://mmedit/res34_en_nomixup')
train_cfg = dict(train_backbone=True)
test_cfg = dict(metrics=['SAD', 'MSE', 'GRAD', 'CONN'])

8 changes: 7 additions & 1 deletion configs/mattors/indexnet/README.md
@@ -16,6 +16,12 @@
| Method | SAD | MSE | GRAD | CONN | Download |
|:----------:|:-----:|:------:|:-----:|:-----:|:--------:|
| M2O DINs (paper) | 45.8 | 0.013 | 25.9 | **43.7** | - |
| M2O DINs (our) | **45.6** | **0.012** | **25.5** | 44.8 | [model](TODO) \| [log](TODO) |
| M2O DINs (our) | **45.6** | **0.012** | **25.5** | 44.8 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/indexnet/indexnet_mobv2_1x16_78k_comp1k_SAD-45.6_20200618_173817-26dd258d.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/indexnet/indexnet_mobv2_1x16_78k_comp1k_20200618_173817.log.json) |

> The best performance reached during training varies widely across random seeds. You may need to run several experiments for each setting to reproduce the results above.

### More results

| Method | SAD | MSE | GRAD | CONN | Download |
|:----------:|:-----:|:------:|:-----:|:-----:|:--------:|
| M2O DINs (with DIM pipeline) | 50.1 | 0.016 | 30.8 | 49.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/indexnet/TODO_to_be_added) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmediting/v0.1/mattors/indexnet/TODO_to_be_added) |