Skip to content

Commit

Permalink
[Fix] fix isr, vsr, vfi inferencer (open-mmlab#1845)
Browse files Browse the repository at this point in the history
* [Fix] fix video inferencer

* [Fix] fix video inferencer
  • Loading branch information
Z-Fran authored May 6, 2023
1 parent 5b5f291 commit e248c7c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 20 deletions.
30 changes: 30 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ python mmagic_inference_demo.py \
--result-out-dir ../resources/output/restoration/demo_restoration_esrgan_res.png
```

```shell
python mmagic_inference_demo.py \
--model-name ttsr \
--img ../resources/input/restoration/0901x2.png \
--ref ../resources/input/restoration/0901x2.png \
--result-out-dir ../resources/output/restoration/demo_restoration_ttsr_res.png
```

#### 2.2.5 Image translation

```shell
Expand Down Expand Up @@ -167,6 +175,17 @@ python mmagic_inference_demo.py \

#### 2.2.8 Video Super-Resolution

BasicVSR / BasicVSR++ / IconVSR / RealBasicVSR

```shell
python mmagic_inference_demo.py \
--model-name basicvsr \
--video ../resources/input/video_restoration/QUuC4vJs_000084_000094_400x320.mp4 \
--result-out-dir ../resources/output/video_restoration/demo_video_restoration_basicvsr_res.mp4
```

EDVR

```shell
python mmagic_inference_demo.py \
--model-name edvr \
Expand All @@ -175,6 +194,17 @@ python mmagic_inference_demo.py \
--result-out-dir ../resources/output/video_restoration/demo_video_restoration_edvr_res.mp4
```

TDAN

```shell
python mmagic_inference_demo.py \
--model-name tdan \
--model-setting 2
--extra-parameters window_size=5 \
--video ../resources/input/video_restoration/QUuC4vJs_000084_000094_400x320.mp4 \
--result-out-dir ../resources/output/video_restoration/demo_video_restoration_edvr_res.mp4
```

#### 2.2.9 Text-to-Image

```shell
Expand Down
24 changes: 7 additions & 17 deletions mmagic/apis/inferencers/image_super_resolution_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
from mmengine import mkdir_or_exist
from mmengine.dataset import Compose
from mmengine.dataset.utils import default_collate as collate

from mmagic.utils import tensor2img
from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType
Expand All @@ -34,7 +33,6 @@ def preprocess(self, img: InputsType, ref: InputsType = None) -> Dict:
data(Dict): Results of preprocess.
"""
cfg = self.model.cfg
device = next(self.model.parameters()).device # model device

# select the data pipeline
if cfg.get('inference_pipeline', None):
Expand Down Expand Up @@ -63,31 +61,22 @@ def preprocess(self, img: InputsType, ref: InputsType = None) -> Dict:

# prepare data
if ref: # Ref-SR
data = dict(img_path=img, gt_path=ref)
data = dict(img_path=img, ref_path=ref)
else: # SISR
data = dict(img_path=img)
_data = test_pipeline(data)

data = dict()
data_preprocessor = cfg['model']['data_preprocessor']
mean = torch.Tensor(data_preprocessor['mean']).view([3, 1, 1])
std = torch.Tensor(data_preprocessor['std']).view([3, 1, 1])
data['inputs'] = (_data['inputs'] - mean) / std
data = collate([data])

if ref:
data['data_samples'] = [_data['data_samples']]
if 'cuda' in str(device):
data['inputs'] = data['inputs'].cuda()
if ref:
data['data_samples'][0] = data['data_samples'][0].cuda()
data['inputs'] = [_data['inputs']]
data['data_samples'] = [_data['data_samples']]

return data

def forward(self, inputs: InputsType) -> PredType:
"""Forward the inputs to the model."""
inputs = self.model.data_preprocessor(inputs)
with torch.no_grad():
result = self.model(mode='tensor', **inputs)
result = self.model(mode='predict', **inputs)
return result

def visualize(self,
Expand All @@ -105,7 +94,8 @@ def visualize(self,
Returns:
List[np.ndarray]: Result of visualize
"""
results = tensor2img(preds[0])
result = preds[0].output.pred_img / 255.
results = tensor2img(result)[..., ::-1]
if result_out_dir:
mkdir_or_exist(os.path.dirname(result_out_dir))
mmcv.imwrite(results, result_out_dir)
Expand Down
7 changes: 5 additions & 2 deletions mmagic/apis/inferencers/video_interpolation_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import cv2
import mmcv
import mmengine
import numpy as np
import torch
from mmengine.dataset import Compose
Expand Down Expand Up @@ -109,6 +110,7 @@ def forward(self,

# check if the output is a video
output_file_extension = os.path.splitext(result_out_dir)[1]
mmengine.utils.mkdir_or_exist(osp.dirname(result_out_dir))
if output_file_extension in VIDEO_EXTENSIONS:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
target = cv2.VideoWriter(result_out_dir, fourcc, output_fps,
Expand Down Expand Up @@ -179,11 +181,12 @@ def forward(self,
self.extra_parameters['end_idx']:
break

logger: MMLogger = MMLogger.get_current_instance()
logger.info(f'Output video is save at {result_out_dir}.')
if to_video:
target.release()

logger: MMLogger = MMLogger.get_current_instance()
logger.info(f'Output video is save at {result_out_dir}.')

return {}

def visualize(self,
Expand Down
7 changes: 6 additions & 1 deletion mmagic/apis/inferencers/video_restoration_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import cv2
import mmcv
import mmengine
import numpy as np
import torch
from mmengine.dataset import Compose
from mmengine.logging import MMLogger
from mmengine.utils import ProgressBar

from mmagic.utils import tensor2img
from .base_mmagic_inferencer import (BaseMMagicInferencer, InputsType,
Expand Down Expand Up @@ -153,13 +155,16 @@ def visualize(self,
List[np.ndarray]: Result of visualize
"""
file_extension = os.path.splitext(result_out_dir)[1]
mmengine.utils.mkdir_or_exist(osp.dirname(result_out_dir))
prog_bar = ProgressBar(preds.size(1))
if file_extension in VIDEO_EXTENSIONS: # save as video
h, w = preds.shape[-2:]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(result_out_dir, fourcc, 25, (w, h))
for i in range(0, preds.size(1)):
img = tensor2img(preds[:, i, :, :, :])
video_writer.write(img.astype(np.uint8))
prog_bar.update()
cv2.destroyAllWindows()
video_writer.release()
else:
Expand All @@ -170,8 +175,8 @@ def visualize(self,
output_i = tensor2img(output_i)
filename_tmpl = self.extra_parameters['filename_tmpl']
save_path_i = f'{result_out_dir}/{filename_tmpl.format(i)}'

mmcv.imwrite(output_i, save_path_i)
prog_bar.update()

logger: MMLogger = MMLogger.get_current_instance()
logger.info(f'Output video is save at {result_out_dir}.')
Expand Down

0 comments on commit e248c7c

Please sign in to comment.