Skip to content

Commit

Permalink
Add param_name to size_dict logs & tidy (huggingface#20205)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts authored Nov 15, 2022
1 parent f1e8c48 commit 55ba319
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 63 deletions.
84 changes: 51 additions & 33 deletions src/transformers/image_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,50 @@ def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")


VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"})


def is_valid_size_dict(size_dict):
if not isinstance(size_dict, dict):
return False

size_dict_keys = set(size_dict.keys())
for allowed_keys in VALID_SIZE_DICT_KEYS:
if size_dict_keys == allowed_keys:
return True
return False


def convert_to_size_dict(
size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
# By default, if size is an int we assume it represents a tuple of (size, size).
if isinstance(size, int) and default_to_square:
if max_size is not None:
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
return {"height": size, "width": size}
# In other configs, if size is an int and default_to_square is False, size represents the length of
# the shortest edge after resizing.
elif isinstance(size, int) and not default_to_square:
size_dict = {"shortest_edge": size}
if max_size is not None:
size_dict["longest_edge"] = max_size
return size_dict
# Otherwise, if size is a tuple it's either (height, width) or (width, height)
elif isinstance(size, (tuple, list)) and height_width_order:
return {"height": size[0], "width": size[1]}
elif isinstance(size, (tuple, list)) and not height_width_order:
return {"height": size[1], "width": size[0]}

raise ValueError(f"Could not convert size input to size dict: {size}")


def get_size_dict(
size: Union[int, Iterable[int], Dict[str, int]] = None,
max_size: Optional[int] = None,
height_width_order: bool = True,
default_to_square: bool = True,
param_name="size",
) -> dict:
"""
Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
Expand All @@ -467,40 +506,19 @@ def get_size_dict(
default_to_square (`bool`, *optional*, defaults to `True`):
If `size` is an int, whether to default to a square image or not.
"""
# If a dict is passed, we check if it's a valid size dict and then return it.
if isinstance(size, dict):
size_keys = set(size.keys())
if (
size_keys != set(["height", "width"])
and size_keys != set(["shortest_edge"])
and size_keys != set(["shortest_edge", "longest_edge"])
):
raise ValueError(
"The size dict must contain either the keys ('height', 'width') or ('shortest_edge')"
f"or ('shortest_edge', 'longest_edge') but got {size_keys}"
)
return size

# By default, if size is an int we assume it represents a tuple of (size, size).
elif isinstance(size, int) and default_to_square:
if max_size is not None:
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
size_dict = {"height": size, "width": size}
# In other configs, if size is an int and default_to_square is False, size represents the length of the shortest edge after resizing.
elif isinstance(size, int) and not default_to_square:
if max_size is not None:
size_dict = {"shortest_edge": size, "longest_edge": max_size}
else:
size_dict = {"shortest_edge": size}
elif isinstance(size, (tuple, list)) and height_width_order:
size_dict = {"height": size[0], "width": size[1]}
elif isinstance(size, (tuple, list)) and not height_width_order:
size_dict = {"height": size[1], "width": size[0]}
if not isinstance(size, dict):
size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
logger.info(
"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
" Converted to {size_dict}.",
)
else:
size_dict = size

logger.info(
"The size parameter should be a dictionary with keys ('height', 'width'), ('shortest_edge', 'longest_edge')"
f" or ('shortest_edge',) got {size}. Setting as {size_dict}.",
)
if not is_valid_size_dict(size_dict):
raise ValueError(
f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
)
return size_dict


Expand Down
10 changes: 5 additions & 5 deletions src/transformers/models/beit/image_processing_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
self.resample = resample
Expand Down Expand Up @@ -152,7 +152,7 @@ def resize(
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
size = get_size_dict(size, default_to_square=True, param_name="size")
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
return resize(
Expand All @@ -178,7 +178,7 @@ def center_crop(
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
size = get_size_dict(size, default_to_square=True, param_name="size")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

def rescale(
Expand Down Expand Up @@ -406,11 +406,11 @@ def preprocess(
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
size = get_size_dict(size, default_to_square=True, param_name="size")
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
Expand Down
8 changes: 5 additions & 3 deletions src/transformers/models/clip/image_processing_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

self.do_resize = do_resize
self.size = size
Expand Down Expand Up @@ -176,6 +176,8 @@ def center_crop(
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` parameter must contain the keys (height, width). Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

def rescale(
Expand Down Expand Up @@ -285,11 +287,11 @@ def preprocess(
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/deit/image_processing_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

self.do_resize = do_resize
self.size = size
Expand Down Expand Up @@ -158,6 +158,8 @@ def center_crop(
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

def rescale(
Expand Down Expand Up @@ -272,7 +274,7 @@ def preprocess(
size = size if size is not None else self.size
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

if not is_batched(images):
images = [images]
Expand Down
14 changes: 8 additions & 6 deletions src/transformers/models/flava/image_processing_flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,12 @@ def __init__(
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
codebook_size = get_size_dict(codebook_size)
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
codebook_crop_size = get_size_dict(codebook_crop_size)
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")

self.do_resize = do_resize
self.size = size
Expand Down Expand Up @@ -360,6 +360,8 @@ def center_crop(
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain 'height' and 'width' keys. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

def rescale(
Expand Down Expand Up @@ -580,7 +582,7 @@ def preprocess(
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
Expand Down Expand Up @@ -612,7 +614,7 @@ def preprocess(
)
codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
codebook_size = codebook_size if codebook_size is not None else self.codebook_size
codebook_size = get_size_dict(codebook_size)
codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
codebook_rescale_factor = (
Expand All @@ -622,7 +624,7 @@ def preprocess(
codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
)
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
codebook_crop_size = get_size_dict(codebook_crop_size)
codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
codebook_do_map_pixels = (
codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
)
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/levit/image_processing_levit.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

self.do_resize = do_resize
self.size = size
Expand Down Expand Up @@ -182,6 +182,8 @@ def center_crop(
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"Size dict must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

def rescale(
Expand Down Expand Up @@ -299,7 +301,7 @@ def preprocess(
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

if not is_batched(images):
images = [images]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
size = size if size is not None else {"shortest_edge": 256}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize
self.size = size
self.resample = resample
Expand Down Expand Up @@ -169,6 +169,8 @@ def center_crop(
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` parameter must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

def rescale(
Expand Down Expand Up @@ -286,7 +288,7 @@ def preprocess(
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

self.do_resize = do_resize
self.size = size
Expand Down Expand Up @@ -182,6 +182,8 @@ def center_crop(
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

def rescale(
Expand Down Expand Up @@ -280,7 +282,7 @@ def preprocess(
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

if not is_batched(images):
images = [images]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
) -> None:
super().__init__(**kwargs)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)

Expand Down Expand Up @@ -141,7 +141,7 @@ def center_crop(
"""
size = self.size if size is None else size
size = get_size_dict(size)
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")

height, width = get_image_size(image)
min_dim = min(height, width)
Expand Down Expand Up @@ -278,7 +278,7 @@ def preprocess(
"""
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size")
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
Expand Down
Loading

0 comments on commit 55ba319

Please sign in to comment.