Skip to content

Commit

Permalink
[Feature] Support use_cache and backend in LoadImageFromFileList (o…
Browse files Browse the repository at this point in the history
…pen-mmlab#857)

* [Feature] Support use_cache and backend in LoadImageFromFileList

* Update
  • Loading branch information
Yshuo-Li authored May 12, 2022
1 parent 9a01375 commit 3351cdc
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
29 changes: 22 additions & 7 deletions mmedit/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self,
self.kwargs = kwargs
self.file_client = None
self.use_cache = use_cache
self.cache = None
self.cache = dict() if use_cache else None
self.backend = backend

def __call__(self, results):
Expand All @@ -67,8 +67,6 @@ def __call__(self, results):
if self.file_client is None:
self.file_client = FileClient(self.io_backend, **self.kwargs)
if self.use_cache:
if self.cache is None:
self.cache = dict()
if filepath in self.cache:
img = self.cache[filepath]
else:
Expand Down Expand Up @@ -132,6 +130,9 @@ class LoadImageFromFileList(LoadImageFromFile):
no conversion is conducted. Default: None.
save_original_img (bool): If True, maintain a copy of the image in
`results` dict with name of `f'ori_{key}'`. Default: False.
use_cache (bool): If True, load all images at once. Default: False.
backend (str): The image loading backend type. Options are `cv2`,
`pillow`, and 'turbojpeg'. Default: None.
kwargs (dict): Args for file client.
"""

Expand Down Expand Up @@ -160,10 +161,24 @@ def __call__(self, results):
if self.save_original_img:
ori_imgs = []
for filepath in filepaths:
img_bytes = self.file_client.get(filepath)
img = mmcv.imfrombytes(
img_bytes, flag=self.flag,
channel_order=self.channel_order) # HWC
if self.use_cache:
if filepath in self.cache:
img = self.cache[filepath]
else:
img_bytes = self.file_client.get(filepath)
img = mmcv.imfrombytes(
img_bytes,
flag=self.flag,
channel_order=self.channel_order,
backend=self.backend) # HWC
self.cache[filepath] = img
else:
img_bytes = self.file_client.get(filepath)
img = mmcv.imfrombytes(
img_bytes,
flag=self.flag,
channel_order=self.channel_order,
backend=self.backend) # HWC

# convert to y-channel, if specified
if self.convert_to is not None:
Expand Down
34 changes: 31 additions & 3 deletions tests/test_data/test_pipelines/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,22 @@ def test_load_image_from_file():
assert id(results['ori_lq']) != id(results['lq'])

# test: use_cache
results = dict(gt_path=path_baboon)
results_ori = dict(gt_path=path_baboon)
config = dict(io_backend='disk', key='gt', use_cache=True)
image_loader = LoadImageFromFile(**config)
assert image_loader.cache is None
assert repr(image_loader) == (
image_loader.__class__.__name__ +
('(io_backend=disk, key=gt, '
'flag=color, save_original_img=False, channel_order=bgr, '
'use_cache=True)'))
results = image_loader(results)
assert not image_loader.cache
results = image_loader(results_ori)
assert image_loader.cache is not None
assert str(path_baboon) in image_loader.cache
assert results['gt'].shape == (480, 500, 3)
assert results['gt_path'] == str(path_baboon)
np.testing.assert_almost_equal(results['gt'], img_baboon)
results = image_loader(results_ori)
assert image_loader.cache is not None
assert str(path_baboon) in image_loader.cache
assert results['gt'].shape == (480, 500, 3)
Expand Down Expand Up @@ -188,6 +194,28 @@ def test_load_image_from_file_list():
with pytest.raises(ValueError):
results = image_loader(results)

# convert to use_cache
results_ori = dict(gt_path=[str(path_baboon_x4), str(path_baboon)])
config = dict(io_backend='disk', key='gt', use_cache=True)
image_loader = LoadImageFromFileList(**config)
assert not image_loader.cache
assert repr(image_loader) == (
image_loader.__class__.__name__ +
('(io_backend=disk, key=gt, '
'flag=color, save_original_img=False, channel_order=bgr, '
'use_cache=True)'))
results = image_loader(results_ori)
assert str(path_baboon) in image_loader.cache
assert len(results['gt']) == 2
assert results['gt'][1].shape == (480, 500, 3)
assert results['gt_path'] == [str(path_baboon_x4), str(path_baboon)]
results = image_loader(results_ori)
assert image_loader.cache is not None
assert str(path_baboon) in image_loader.cache
assert len(results['gt']) == 2
assert results['gt'][1].shape == (480, 500, 3)
assert results['gt_path'] == [str(path_baboon_x4), str(path_baboon)]


class TestMattingLoading:

Expand Down

0 comments on commit 3351cdc

Please sign in to comment.