Skip to content

Commit bfc0dcc

Browse files
committed
Improve image extension handling, add methods to modify / get defaults. Fix #1335 fix #1274.
1 parent 7d4b380 commit bfc0dcc

File tree

8 files changed

+103
-23
lines changed

8 files changed

+103
-23
lines changed

timm/data/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from .dataset_factory import create_dataset
77
from .loader import create_loader
88
from .mixup import Mixup, FastCollateMixup
9-
from .parsers import create_parser
9+
from .parsers import create_parser,\
10+
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
1011
from .real_labels import RealLabelsImagenet
1112
from .transforms import *
12-
from .transforms_factory import create_transform
13+
from .transforms_factory import create_transform

timm/data/parsers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .parser_factory import create_parser
2+
from .img_extensions import *

timm/data/parsers/constants.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

timm/data/parsers/img_extensions.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
""" Registry of the file extensions that are treated as images.

The registry is process-wide module state. It is exposed as an ordered tuple
(`IMG_EXTENSIONS`) and mirrored in a private set for fast membership tests. Use
the functions listed in `__all__` to query or modify it.

NOTE: `parser_image_tar` lowercases a file's extension before testing membership,
so extensions should be registered in lowercase to be matched there.
"""
from copy import deepcopy

__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']


# NOTE the setters below *rebind* IMG_EXTENSIONS, so code that did
# `from ... import IMG_EXTENSIONS` keeps seeing the tuple it originally imported.
# Prefer get_img_extensions() to always observe the current value.
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')  # singleton, kept public for bwd compat use
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS)  # set version, private, kept in sync


def _set_extensions(extensions):
    """ Replace the registry with `extensions`, dropping duplicates but keeping first-seen order.

    The tuple and the set are both built from a single pass over `extensions`, so
    they stay in sync even if a one-shot iterable (e.g. a generator) is passed.
    """
    global IMG_EXTENSIONS
    global _IMG_EXTENSIONS_SET
    seen = set()
    ordered = []
    for x in extensions:
        if x not in seen:
            seen.add(x)
            ordered.append(x)
    IMG_EXTENSIONS = tuple(ordered)
    _IMG_EXTENSIONS_SET = seen


def _valid_extension(x: str) -> bool:
    """ Return True if `x` is a string of the form '.abc' (leading dot + at least one more char). """
    return isinstance(x, str) and len(x) >= 2 and x.startswith('.')


def _as_collection(ext):
    """ Return `ext` as a tuple; anything that is not a list / tuple / set / frozenset is one item. """
    if isinstance(ext, (list, tuple, set, frozenset)):
        return tuple(ext)
    return (ext,)


def _check_extensions(extensions):
    """ Raise AssertionError for the first entry of `extensions` that is not a valid extension.

    Raised explicitly (rather than via `assert`) so validation still runs under
    `python -O`, while callers that catch AssertionError keep working.
    """
    for x in extensions:
        if not _valid_extension(x):
            raise AssertionError(
                f'Invalid image extension {x!r}, expected a string starting with "." such as ".png"')


def is_img_extension(ext):
    """ Return True if `ext` (exact, case-sensitive match) is a registered image extension. """
    return ext in _IMG_EXTENSIONS_SET


def get_img_extensions(as_set=False):
    """ Return the registered image extensions.

    Args:
        as_set: return a set (a copy, safe to mutate) instead of the ordered tuple

    Returns:
        Tuple of extensions in registration order, or a set copy if `as_set` is True
    """
    return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)


def set_img_extensions(extensions):
    """ Replace all registered image extensions.

    Args:
        extensions: non-empty iterable of extension strings, each starting with '.'

    Raises:
        AssertionError: if `extensions` is empty or contains an invalid entry; the
            registry is left unchanged in that case
    """
    extensions = tuple(extensions)  # materialize so one-shot iterables can be validated, then stored
    if not extensions:
        raise AssertionError('At least one image extension must be specified')
    _check_extensions(extensions)
    _set_extensions(extensions)


def add_img_extensions(ext):
    """ Register one or more additional image extensions (already registered ones are ignored).

    Args:
        ext: a single extension string, or a list / tuple / set / frozenset of them

    Raises:
        AssertionError: if any entry is not a valid extension; the registry is left unchanged
    """
    ext = _as_collection(ext)
    _check_extensions(ext)
    _set_extensions(IMG_EXTENSIONS + ext)


def del_img_extensions(ext):
    """ Unregister one or more image extensions; entries that are not registered are ignored.

    Args:
        ext: a single extension string, or a list / tuple / set / frozenset of them
    """
    ext = _as_collection(ext)
    _set_extensions(tuple(x for x in IMG_EXTENSIONS if x not in ext))

timm/data/parsers/parser_factory.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22

33
from .parser_image_folder import ParserImageFolder
4-
from .parser_image_tar import ParserImageTar
54
from .parser_image_in_tar import ParserImageInTar
65

76

timm/data/parsers/parser_image_folder.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,35 @@
66
Hacked together by / Copyright 2020 Ross Wightman
77
"""
88
import os
9+
from typing import Dict, List, Optional, Set, Tuple, Union
910

1011
from timm.utils.misc import natural_key
1112

12-
from .parser import Parser
1313
from .class_map import load_class_map
14-
from .constants import IMG_EXTENSIONS
14+
from .img_extensions import get_img_extensions
15+
from .parser import Parser
16+
17+
18+
def find_images_and_targets(
19+
folder: str,
20+
types: Optional[Union[List, Tuple, Set]] = None,
21+
class_to_idx: Optional[Dict] = None,
22+
leaf_name_only: bool = True,
23+
sort: bool = True
24+
):
25+
""" Walk folder recursively to discover images and map them to classes by folder names.
1526
27+
Args:
28+
folder: root of folder to recursively search
29+
types: types (file extensions) to search for in path
30+
class_to_idx: specify mapping for class (folder name) to class index if set
31+
leaf_name_only: use only leaf-name of folder walk for class names
32+
sort: re-sort found images by name (for consistent ordering)
1633
17-
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
34+
Returns:
35+
A list of image and target tuples, class_to_idx mapping
36+
"""
37+
types = get_img_extensions(as_set=True) if not types else set(types)
1838
labels = []
1939
filenames = []
2040
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
@@ -51,7 +71,8 @@ def __init__(
5171
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
5272
if len(self.samples) == 0:
5373
raise RuntimeError(
54-
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
74+
f'Found 0 images in subfolders of {root}. '
75+
f'Supported image extensions are {", ".join(get_img_extensions())}')
5576

5677
def __getitem__(self, index):
5778
path, target = self.samples[index]

timm/data/parsers/parser_image_in_tar.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
1010
Hacked together by / Copyright 2020 Ross Wightman
1111
"""
12+
import logging
1213
import os
13-
import tarfile
1414
import pickle
15-
import logging
16-
import numpy as np
15+
import tarfile
1716
from glob import glob
18-
from typing import List, Dict
17+
from typing import List, Tuple, Dict, Set, Optional, Union
18+
19+
import numpy as np
1920

2021
from timm.utils.misc import natural_key
2122

22-
from .parser import Parser
2323
from .class_map import load_class_map
24-
from .constants import IMG_EXTENSIONS
25-
24+
from .img_extensions import get_img_extensions
25+
from .parser import Parser
2626

2727
_logger = logging.getLogger(__name__)
2828
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -39,7 +39,7 @@ def reset(self):
3939
self.tf = None
4040

4141

42-
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
42+
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
4343
sample_count = 0
4444
for i, ti in enumerate(tf):
4545
if not ti.isfile():
@@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
6060
return sample_count
6161

6262

63-
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
63+
def extract_tarinfos(
64+
root,
65+
class_name_to_idx: Optional[Dict] = None,
66+
cache_tarinfo: Optional[bool] = None,
67+
extensions: Optional[Union[List, Tuple, Set]] = None,
68+
sort: bool = True
69+
):
70+
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
6471
root_is_tar = False
6572
if os.path.isfile(root):
6673
assert os.path.splitext(root)[-1].lower() == '.tar'
@@ -176,8 +183,8 @@ def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
176183
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
177184
self.root,
178185
class_name_to_idx=class_name_to_idx,
179-
cache_tarinfo=cache_tarinfo,
180-
extensions=IMG_EXTENSIONS)
186+
cache_tarinfo=cache_tarinfo
187+
)
181188
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
182189
if len(tarfiles) == 1 and tarfiles[0][0] is None:
183190
self.root_is_tar = True

timm/data/parsers/parser_image_tar.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
import os
99
import tarfile
1010

11-
from .parser import Parser
12-
from .class_map import load_class_map
13-
from .constants import IMG_EXTENSIONS
1411
from timm.utils.misc import natural_key
1512

13+
from .class_map import load_class_map
14+
from .img_extensions import get_img_extensions
15+
from .parser import Parser
16+
1617

1718
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
19+
extensions = get_img_extensions(as_set=True)
1820
files = []
1921
labels = []
2022
for ti in tarfile.getmembers():
@@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
2325
dirname, basename = os.path.split(ti.path)
2426
label = os.path.basename(dirname)
2527
ext = os.path.splitext(basename)[1]
26-
if ext.lower() in IMG_EXTENSIONS:
28+
if ext.lower() in extensions:
2729
files.append(ti)
2830
labels.append(label)
2931
if class_to_idx is None:

0 commit comments

Comments
 (0)