Skip to content

Commit

Permalink
[Fix] guided diffusion inference demo code (open-mmlab#1862)
Browse files Browse the repository at this point in the history
* [Fix] demo code for guided-diffusion

* [Fix] typos
  • Loading branch information
SheffieldCao authored May 23, 2023
1 parent 97f8a9b commit d2429b5
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions configs/guided_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,44 @@ We show that diffusion models can achieve image sample quality superior to the c
You can run adm as follows:

```python
from mmagic.utils import register_all_modules
from mmagic.apis import init_model
from mmengine import Config, MODELS
from mmengine.registry import init_default_scope
from torchvision.utils import save_image

register_all_modules()

# sampling without classifier guidance
config = 'configs/guided_diffusion/adm_ddim250_8xb32_imagenet-64x64.py'
ckpt_path = 'https://download.openmmlab.com/mmediting/guided_diffusion/adm-u-cvt-rgb_8xb32_imagenet-64x64-7ff0080b.pth' # noqa
model = init_model(config, ckpt_path)
samples = model.infer(
init_image=None,
batch_size=4,
num_inference_steps=250,
labels=None,
classifier_scale=0.0,
show_progress=True)['samples']
init_default_scope('mmagic')

# sampling without classifier guidance, CGS=1.0
config = 'configs/guided_diffusion/adm-g_ddim25_8xb32_imagenet-64x64.py'
ckpt_path = 'https://download.openmmlab.com/mmediting/guided_diffusion/adm-g_8xb32_imagenet-64x64-2c0fbeda.pth' # noqa
model = init_model(config, ckpt_path)

model_cfg = Config.fromfile(config).model
model_cfg.pretrained_cfgs = dict(unet=dict(ckpt_path=ckpt_path, prefix='unet'),
classifier=dict(ckpt_path=ckpt_path, prefix='classifier'))
model = MODELS.build(model_cfg).cuda().eval()

samples = model.infer(
init_image=None,
batch_size=4,
num_inference_steps=25,
labels=333,
classifier_scale=1.0,
show_progress=True)['samples']

# sampling without classifier guidance
config = 'configs/guided_diffusion/adm_ddim250_8xb32_imagenet-64x64.py'
ckpt_path = 'https://download.openmmlab.com/mmediting/guided_diffusion/adm-u-cvt-rgb_8xb32_imagenet-64x64-7ff0080b.pth' # noqa

model_cfg = Config.fromfile(config).model
model_cfg.pretrained_cfgs = dict(unet=dict(ckpt_path=ckpt_path, prefix='unet'))
model = MODELS.build(model_cfg).cuda().eval()

samples = model.infer(
init_image=None,
batch_size=4,
num_inference_steps=250,
labels=None,
classifier_scale=0.0,
show_progress=True)['samples']
```

**Test**
Expand Down

0 comments on commit d2429b5

Please sign in to comment.