Skip to content

Commit

Permalink
[Feature] Support CPU demo (open-mmlab#848)
Browse files Browse the repository at this point in the history
* [Fix] Fix args

* [Feature] Support CPU demo

* Update
  • Loading branch information
Yshuo-Li authored Apr 19, 2022
1 parent e7deaa9 commit 998109a
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 18 deletions.
8 changes: 6 additions & 2 deletions demo/generation_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ def parse_args():
def main():
args = parse_args()

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))
if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoint, device=device)

output = generation_inference(model, args.img_path, args.unpaired_path)

Expand Down
8 changes: 6 additions & 2 deletions demo/inpainting_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def parse_args():
def main():
args = parse_args()

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))
if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoint, device=device)

result = inpainting_inference(model, args.masked_img_path, args.mask_path)
result = tensor2img(result, min_max=(-1, 1))[..., ::-1]
Expand Down
8 changes: 6 additions & 2 deletions demo/matting_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ def parse_args():
def main():
args = parse_args()

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))
if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoint, device=device)

pred_alpha = matting_inference(model, args.img_path,
args.trimap_path) * 255
Expand Down
8 changes: 6 additions & 2 deletions demo/restoration_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ def main():
'you may want to use "ref_path=None" '
'for single restoration.')

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))
if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoint, device=device)

if args.ref_path: # Ref-SR
output = restoration_inference(model, args.img_path, args.ref_path)
Expand Down
8 changes: 6 additions & 2 deletions demo/restoration_face_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ def main():
'you may want to use "restoration_video_demo.py" '
'for video restoration.')

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))
if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoint, device=device)

output = restoration_face_inference(model, args.img_path,
args.upscale_factor, args.face_size)
Expand Down
8 changes: 6 additions & 2 deletions demo/restoration_video_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ def main():

args = parse_args()

model = init_model(
args.config, args.checkpoint, device=torch.device('cuda', args.device))
if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoint, device=device)

output = restoration_video_inference(model, args.input_dir,
args.window_size, args.start_idx,
Expand Down
3 changes: 2 additions & 1 deletion demo/video_interpolation_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ def main():

args = parse_args()

if args.device < 0:
if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoint, device=device)

video_interpolation_inference(
Expand Down
4 changes: 3 additions & 1 deletion mmedit/apis/generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def generation_inference(model, img, img_unpaired=None):
else:
data = dict(img_a_path=img, img_b_path=img_unpaired)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
data = collate([data], samples_per_gpu=1)
if 'cuda' in str(device):
data = scatter(data, [device])[0]
# forward the model
with torch.no_grad():
results = model(test_mode=True, **data)
Expand Down
6 changes: 5 additions & 1 deletion mmedit/apis/inpainting_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def inpainting_inference(model, masked_img, mask):
# prepare data
data = dict(masked_img_path=masked_img, mask_path=mask)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
data = collate([data], samples_per_gpu=1)
if 'cuda' in str(device):
data = scatter(data, [device])[0]
else:
data.pop('meta')
# forward the model
with torch.no_grad():
result = model(test_mode=True, **data)
Expand Down
4 changes: 3 additions & 1 deletion mmedit/apis/matting_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def matting_inference(model, img, trimap):
# prepare data
data = dict(merged_path=img, trimap_path=trimap)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
data = collate([data], samples_per_gpu=1)
if 'cuda' in str(device):
data = scatter(data, [device])[0]
# forward the model
with torch.no_grad():
result = model(test_mode=True, **data)
Expand Down
4 changes: 3 additions & 1 deletion mmedit/apis/restoration_face_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def restoration_face_inference(model, img, upscale_factor=1, face_size=1024):
# prepare data
data = dict(lq=img.astype(np.float32))
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
data = collate([data], samples_per_gpu=1)
if 'cuda' in str(device):
data = scatter(data, [device])[0]

with torch.no_grad():
output = model(test_mode=True, **data)['output'].clip_(0, 1)
Expand Down
4 changes: 3 additions & 1 deletion mmedit/apis/restoration_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def restoration_inference(model, img, ref=None):
else: # SISR
data = dict(lq_path=img)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
data = collate([data], samples_per_gpu=1)
if 'cuda' in str(device):
data = scatter(data, [device])[0]
# forward the model
with torch.no_grad():
result = model(test_mode=True, **data)
Expand Down

0 comments on commit 998109a

Please sign in to comment.