diff --git a/mmedit/datasets/pipelines/loading.py b/mmedit/datasets/pipelines/loading.py index 395053066f..c7f162732c 100644 --- a/mmedit/datasets/pipelines/loading.py +++ b/mmedit/datasets/pipelines/loading.py @@ -50,7 +50,7 @@ def __init__(self, self.kwargs = kwargs self.file_client = None self.use_cache = use_cache - self.cache = None + self.cache = dict() if use_cache else None self.backend = backend def __call__(self, results): @@ -67,8 +67,6 @@ def __call__(self, results): if self.file_client is None: self.file_client = FileClient(self.io_backend, **self.kwargs) if self.use_cache: - if self.cache is None: - self.cache = dict() if filepath in self.cache: img = self.cache[filepath] else: @@ -132,6 +130,9 @@ class LoadImageFromFileList(LoadImageFromFile): no conversion is conducted. Default: None. save_original_img (bool): If True, maintain a copy of the image in `results` dict with name of `f'ori_{key}'`. Default: False. + use_cache (bool): If True, load all images at once. Default: False. + backend (str): The image loading backend type. Options are `cv2`, + `pillow`, and 'turbojpeg'. Default: None. kwargs (dict): Args for file client. """ @@ -160,10 +161,24 @@ def __call__(self, results): if self.save_original_img: ori_imgs = [] for filepath in filepaths: - img_bytes = self.file_client.get(filepath) - img = mmcv.imfrombytes( - img_bytes, flag=self.flag, - channel_order=self.channel_order) # HWC + if self.use_cache: + if filepath in self.cache: + img = self.cache[filepath] + else: + img_bytes = self.file_client.get(filepath) + img = mmcv.imfrombytes( + img_bytes, + flag=self.flag, + channel_order=self.channel_order, + backend=self.backend) # HWC + self.cache[filepath] = img + else: + img_bytes = self.file_client.get(filepath) + img = mmcv.imfrombytes( + img_bytes, + flag=self.flag, + channel_order=self.channel_order, + backend=self.backend) # HWC # convert to y-channel, if specified if self.convert_to is not None: diff --git a/tests/test_data/test_pipelines/test_loading.py b/tests/test_data/test_pipelines/test_loading.py index 07c6d5cb8f..1280889baa 100644 --- a/tests/test_data/test_pipelines/test_loading.py +++ b/tests/test_data/test_pipelines/test_loading.py @@ -70,16 +70,22 @@ def test_load_image_from_file(): assert id(results['ori_lq']) != id(results['lq']) # test: use_cache - results = dict(gt_path=path_baboon) + results_ori = dict(gt_path=path_baboon) config = dict(io_backend='disk', key='gt', use_cache=True) image_loader = LoadImageFromFile(**config) - assert image_loader.cache is None assert repr(image_loader) == ( image_loader.__class__.__name__ + ('(io_backend=disk, key=gt, ' 'flag=color, save_original_img=False, channel_order=bgr, ' 'use_cache=True)')) - results = image_loader(results) + assert not image_loader.cache + results = image_loader(results_ori) + assert image_loader.cache is not None + assert str(path_baboon) in image_loader.cache + assert results['gt'].shape == (480, 500, 3) + assert results['gt_path'] == str(path_baboon) + np.testing.assert_almost_equal(results['gt'], img_baboon) + results = image_loader(results_ori) assert image_loader.cache is not None assert str(path_baboon) in image_loader.cache assert results['gt'].shape == (480, 500, 3) @@ -188,6 +194,28 @@ def test_load_image_from_file_list(): with pytest.raises(ValueError): results = image_loader(results) + # convert to use_cache + results_ori = dict(gt_path=[str(path_baboon_x4), str(path_baboon)]) + config = dict(io_backend='disk', key='gt', use_cache=True) + image_loader = LoadImageFromFileList(**config) + assert not image_loader.cache + assert repr(image_loader) == ( + image_loader.__class__.__name__ + + ('(io_backend=disk, key=gt, ' + 'flag=color, save_original_img=False, channel_order=bgr, ' + 'use_cache=True)')) + results = image_loader(results_ori) + assert str(path_baboon) in image_loader.cache + assert len(results['gt']) == 2 + assert results['gt'][1].shape == (480, 500, 3) + assert results['gt_path'] == [str(path_baboon_x4), str(path_baboon)] + results = image_loader(results_ori) + assert image_loader.cache is not None + assert str(path_baboon) in image_loader.cache + assert len(results['gt']) == 2 + assert results['gt'][1].shape == (480, 500, 3) + assert results['gt_path'] == [str(path_baboon_x4), str(path_baboon)] + class TestMattingLoading: